diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp deleted file mode 100644 index 2d5fd85827b2751085a78dcb241aa3cf081470d5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp +++ /dev/null @@ -1,164 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing pipelined epilogues with bias add and elementwise activation functions. - This collective is now DEPRECATED, will be removed in the next release. Use EVT instead. -*/ - -#pragma once - -#include "sm90_epilogue_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int StagesC_, - int StagesD_, - int FragmentSize_, - class BlockTileShape_, // (BLK_M,BLK_N,BLK_K) - class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N) - class ElementC_, - class StrideC_, - class ElementD_, - class StrideD_, - class FusionCallbacks_, - class CopyOpG2S_, - class SmemLayoutAtomC_, - class CopyOpS2R_, - class CopyOpS2G_, - class SmemLayoutAtomD_, - class CopyOpR2S_, - class CopyAtomC_, - class CopyOpR2R_ -> -class Sm90EpilogueTmaWarpSpecializedBiasElementwise - : public CollectiveEpilogue< - Sm90TmaWarpSpecialized, - BlockTileShape_, - EpilogueTileShape_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpG2S_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpS2G_, - SmemLayoutAtomD_, - CopyOpR2S_, - CopyAtomC_, - CopyOpR2R_ -> { -private: - using Impl = - CollectiveEpilogue< - Sm90TmaWarpSpecialized, - BlockTileShape_, - EpilogueTileShape_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - FusionCallbacks_, - CopyOpG2S_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpS2G_, - SmemLayoutAtomD_, - CopyOpR2S_, - CopyAtomC_, - CopyOpR2R_ - >; -public: - using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; - using ElementCompute = typename Impl::ThreadEpilogueOp::ElementCompute; - using ElementBias = typename Impl::ThreadEpilogueOp::ElementBias; - using ElementT = typename Impl::ThreadEpilogueOp::ElementAux; - - // Constructor inheritance - using Impl::Impl; - - // Host side epilogue arguments - struct [[deprecated("use Sm90TmaWarpSpecialized Arguments instead")]] - Arguments { - struct ThreadArgs { - ElementCompute alpha{1}; - ElementCompute beta{0}; - ElementCompute const *alpha_ptr{nullptr}; - ElementCompute const *beta_ptr{nullptr}; - } thread; - ElementC_ const* ptr_C{nullptr}; - StrideC_ dC{}; - ElementD_* ptr_D{nullptr}; - StrideD_ dD{}; - ElementBias const* ptr_Bias{nullptr}; - ElementT* ptr_T{nullptr}; - - CUTLASS_HOST_DEVICE - operator typename Impl::Arguments() const { - typename Impl::Arguments arguments; - arguments.thread.alpha = thread.alpha; - arguments.thread.beta = thread.beta; - arguments.thread.alpha_ptr = thread.alpha_ptr; - arguments.thread.beta_ptr = thread.beta_ptr; - if constexpr (not cute::is_void_v) { - arguments.thread.bias_ptr = ptr_Bias; - } - if constexpr (not cute::is_void_v) { - arguments.thread.aux_ptr = ptr_T; - arguments.thread.dAux = dD; - } - arguments.ptr_C = ptr_C; - arguments.dC = dC; - arguments.ptr_D = ptr_D; - arguments.dD = dD; - - return arguments; - } - }; - -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/dispatch_policy.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/dispatch_policy.hpp deleted file mode 100644 index ca91ac19b0aadfeddcfb030ee16f03905855cd63..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/dispatch_policy.hpp +++ /dev/null @@ -1,302 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/scale_type.h" - -////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue { - -////////////////////////////////////////////////////////////////////////////// - -////////////////////////////////////////////////////////////////////////////// -// -// Builder Epilogue Schedules -// -////////////////////////////////////////////////////////////////////////////// -// Pre-Hopper schedules -struct PtrArrayDefault {}; -struct EpilogueSimtVectorized {}; -struct EpiloguePtrArraySimtVectorized {}; -// Hopper direct store schedules -struct NoSmemWarpSpecialized {}; -struct PtrArrayNoSmemWarpSpecialized {}; -struct PtrArrayNoSmemWarpSpecializedTransposed {}; -// Hopper TMA schedules -struct TmaWarpSpecialized {}; -struct TmaWarpSpecializedCooperative {}; -struct PtrArrayTmaWarpSpecialized { static constexpr int NumEpilogueWarpGroups = 1; }; -struct PtrArrayTmaWarpSpecializedPingpong { static constexpr int NumEpilogueWarpGroups = 2; }; -struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueWarpGroups = 2; }; -// Blackwell direct store schedules -struct NoSmemWarpSpecialized1Sm {}; -struct NoSmemWarpSpecialized2Sm {}; -struct FastF32NoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; -struct FastF32NoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; -struct BlockwiseNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; -struct BlockwiseNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; -struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; -struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; -struct PtrArrayFastF32NoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; -struct PtrArrayFastF32NoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; -struct PtrArrayBlockwiseNoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; -struct PtrArrayBlockwiseNoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; -// Blackwell TMA schedules -struct TmaWarpSpecialized1Sm {}; -struct TmaWarpSpecialized2Sm {}; -struct PtrArrayTmaWarpSpecialized1Sm : TmaWarpSpecialized1Sm {}; -struct PtrArrayTmaWarpSpecialized2Sm : TmaWarpSpecialized2Sm {}; -struct TmaWarpSpecialized1SmNvf4 final : TmaWarpSpecialized1Sm {}; -struct TmaWarpSpecialized2SmNvf4 final : TmaWarpSpecialized2Sm {}; -struct TmaWarpSpecialized1SmMxf4 final : TmaWarpSpecialized1Sm {}; -struct TmaWarpSpecialized2SmMxf4 final : TmaWarpSpecialized2Sm {}; -struct TmaWarpSpecialized1SmMxf8f6f4 final : TmaWarpSpecialized1Sm {}; -struct TmaWarpSpecialized2SmMxf8f6f4 final : TmaWarpSpecialized2Sm {}; -// Cooperative epilogue schedule for sm120 sparse kernels -struct SparseTmaWarpSpecializedCooperativeSm120 : public TmaWarpSpecializedCooperative {}; - -// DEPRECATED schedules, will be removed in next release -struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; -struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; -template < - template class ActivationFunctor_, - thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, - FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest -> -struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] -TmaWarpSpecializedElementwise : public TmaWarpSpecializedElementwiseBase { - template - using ActivationFunctor = ActivationFunctor_; - static constexpr thread::ScaleType::Kind Scale = Scale_; - static constexpr FloatRoundStyle Round = Round_; -}; - -template < - template class ActivationFunctor_, - thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, - FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest -> -struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombEltAct instead")]] -TmaWarpSpecializedCooperativeElementwise : public TmaWarpSpecializedCooperativeElementwiseBase { - template - using ActivationFunctor = ActivationFunctor_; - static constexpr thread::ScaleType::Kind Scale = Scale_; - static constexpr FloatRoundStyle Round = Round_; -}; - -struct TmaWarpSpecializedBiasElementwiseBase : public TmaWarpSpecialized{}; -struct TmaWarpSpecializedCooperativeBiasElementwiseBase : public TmaWarpSpecializedCooperative {}; - -template < - template class ActivationFunctor_, - class ElementT_, - template class BiasOp_, - bool StoreT_, - class ElementBias_ -> -struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltActAux instead")]] -TmaWarpSpecializedBiasElementwise : public TmaWarpSpecializedBiasElementwiseBase { - template - using ActivationFunctor = ActivationFunctor_; - using ElementT = ElementT_; - - template - using BiasOp = BiasOp_; - - static constexpr bool StoreT = StoreT_; - using ElementBias = ElementBias_; -}; - -template < - template class ActivationFunctor_, - class ElementT_, - template class BiasOp_, - bool StoreT_, - class ElementBias_ -> -struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombPerRowBiasEltActAux instead")]] -TmaWarpSpecializedCooperativeBiasElementwise : public TmaWarpSpecializedCooperativeBiasElementwiseBase { - template - using ActivationFunctor = ActivationFunctor_; - - using ElementT = ElementT_; - - template - using BiasOp = BiasOp_; - - static constexpr bool StoreT = StoreT_; - using ElementBias = ElementBias_; -}; - -////////////////////////////////////////////////////////////////////////////// -// -// Collective Dispatch Policies -// -////////////////////////////////////////////////////////////////////////////// - -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_ -> -struct Sm90TmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; -}; - -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_, - int NumEpilogueWarpGroups_ -> -struct Sm90PtrArrayTmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; - constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; -}; - -// DEPRECATED policies, will be removed in next release -template< - int StagesC_, - int StagesD_, - int FragmentSize_ = 2 -> -struct Sm90TmaWarpSpecializedBiasElementwise { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; -}; - - -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_ -> -struct Sm100TmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; -}; - -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_ -> -struct Sm100PtrArrayTmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; - - static_assert(StagesC >= 1, "StagesC must be >= 1"); - static_assert(StagesD >= 1, "StagesD must be >= 1"); -}; - -struct Sm100NoSmem { - constexpr static int StagesC = 1; - constexpr static int StagesD = 1; - constexpr static int FragmentSize = 1; -}; - -struct Sm100NoSmemWarpSpecialized { - constexpr static int StagesC = 1; - constexpr static int StagesD = 1; - constexpr static int FragmentSize = 1; -}; - -struct Sm100PtrArrayNoSmem { - constexpr static int StagesC = 1; - constexpr static int StagesD = 1; - constexpr static int FragmentSize = 1; -}; - -struct Sm100PtrArrayNoSmemWarpSpecialized { - constexpr static int StagesC = 1; - constexpr static int StagesD = 1; - constexpr static int FragmentSize = 1; -}; -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_ -> -struct Sm120TmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; -}; - -template< - int StagesC_, - int StagesD_, - int FragmentSize_, - bool ReuseSmemC_, - bool DelayTmaStore_, - int NumEpilogueWarpGroups_ -> -struct Sm120PtrArrayTmaWarpSpecialized { - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static int FragmentSize = FragmentSize_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool DelayTmaStore = DelayTmaStore_; - constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; -}; - -////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp deleted file mode 100644 index f9febeec4d92d54ec02e221d028f7329c2edeea5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp +++ /dev/null @@ -1,91 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/detail/dependent_false.hpp" -#include "cutlass/epilogue/fusion/operations.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Dispatch interface for epilogue fusion callbacks -// For visitor fusions, this is just a convenience wrapper to provide metadata and non-nested args. -// It is also valid to just pass visitor callbacks directly to the collective, e.g. fusion::Sm90LinearCombination, -// provided the collective supports a visitor callbacks interface. This is useful for implementing custom fusions. -template < - class DispatchPolicy, // specialize on collective's dispatch policy since callbacks API will depend on collective's algorithm - class Operation, // the fusion operation being performed, e.g. fusion::LinearCombination - class CtaTile_MNK, // computed tile per CTA - class EpilogueTile_MN, // epilogue subtile size - class... Args // callbacks implementation dependent args (e.g. copy atoms, smem layouts) -> -struct FusionCallbacks { - static_assert(cutlass::detail::dependent_false, "Could not find a callbacks specialization."); -}; - -// Metadata helper to handle custom EVTs or other non-FusionCallbacks types -template -struct FusionCallbacksTraits { - using DispatchPolicy = void; - using Callbacks = T; - using Operation = FusionOperation; - using CtaTile_MNK = void; - using EpilogueTile_MN = void; - using ElementCompute = void; -}; - -template < - class DispatchPolicy_, - class Operation_, - class CtaTile_MNK_, - class EpilogueTile_MN_, - class... Args -> -struct FusionCallbacksTraits< - FusionCallbacks -> { - using DispatchPolicy = DispatchPolicy_; - using Callbacks = FusionCallbacks; - using Operation = Operation_; - using CtaTile_MNK = CtaTile_MNK_; - using EpilogueTile_MN = EpilogueTile_MN_; - using ElementCompute = typename Operation::ElementCompute; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/operations.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/operations.hpp deleted file mode 100644 index 114737a9d910a458f4895212d0904e002a9aeec8..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/operations.hpp +++ /dev/null @@ -1,645 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include -#include -#include -#include // cute::false_type - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Fusion Operations -// Template args must not be implementation dependent -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -struct FusionOperation { - // metadata types/queries that can be overrided - using ElementOutput = void; - using ElementCompute = void; - FloatRoundStyle RoundStyle = FloatRoundStyle::round_indeterminate; - - using ElementSource = void; - static constexpr bool IsSourceSupported = false; - static constexpr bool IsResidualSupported = false; // Source is added after activation - - using ElementScalar = void; - static constexpr int AlignmentScalar = 0; - static constexpr bool IsScaleFactorSupported = false; - static constexpr bool IsPerRowScaleSupported = false; - static constexpr bool IsPerColScaleSupported = false; - - using ElementBias = void; - static constexpr int AlignmentBias = 0; - static constexpr bool IsPerRowBiasSupported = false; - static constexpr bool IsPerColBiasSupported = false; - static constexpr bool IsDePerRowBiasSupported = false; - - using ActivationFn = void; - static constexpr bool IsEltActSupported = false; - static constexpr bool IsDeEltActSupported = false; - - using ElementAux = void; - using GmemLayoutTagAux = void; - static constexpr int AlignmentAux = 0; - static constexpr bool IsAuxOutSupported = false; - static constexpr bool IsAuxInSupported = false; - - using ElementAmax = void; - static constexpr bool IsAbsMaxSupported = false; - - using ElementBlockScaleFactor = void; - static constexpr int SFVecSize = 0; - static constexpr bool IsBlockScaleSupported = false; // Umbrella variable to check BlockScaling support in the epilogues - using GmemLayoutTagScalefactor = void; -}; - -// D = alpha * acc -template< - class ElementOutput_, - class ElementCompute_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct ScaledAcc : FusionOperation { - using ElementOutput = ElementOutput_; - using ElementCompute = ElementCompute_; - using ElementScalar = ElementScalar_; - static constexpr int AlignmentScalar = 1; - static constexpr auto RoundStyle = RoundStyle_; -}; - -// D = alpha * acc + beta * C -template< - class ElementOutput_, - class ElementCompute_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinearCombination - : ScaledAcc { - using ElementSource = ElementSource_; - static constexpr bool IsSourceSupported = true; -}; - -// D = activation(alpha * acc + beta * C) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombEltAct - : LinearCombination { - using ActivationFn = ActivationFn_; - static constexpr bool IsEltActSupported = true; -}; - -// D = softmax(top_k(alpha * acc + beta * C)) -template< - int TopK, - class ElementOutput_, - class ElementCompute_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombTopKSoftmaxCol - : LinearCombination { -}; - - -// D = alpha * acc + beta * C + per-row bias -template< - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerRowBias - : LinearCombination { - using ElementBias = ElementBias_; - static constexpr int AlignmentBias = AlignmentBias_; - static constexpr bool IsPerRowBiasSupported = true; -}; - -// D = alpha * acc + beta * C + per-column bias -template< - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerColBias - : LinearCombination { - using ElementBias = ElementBias_; - static constexpr int AlignmentBias = AlignmentBias_; - static constexpr bool IsPerColBiasSupported = true; -}; - -// D = activation(alpha * acc + beta * C + per-row bias) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerRowBiasEltAct - : LinCombPerRowBias { - using ActivationFn = ActivationFn_; - static constexpr bool IsEltActSupported = true; -}; - -// Grouped Wgrad's D = alpha * acc + beta * C with special AccFetch. -template< - class GroupsPerTile_, - class ElementOutput_, - class ElementCompute_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinearCombinationGroupedWgrad - : LinearCombination { - using GroupsPerTile = GroupsPerTile_; -}; - -// D = activation(alpha * acc + beta * C + per-column bias) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerColBiasEltAct - : LinCombPerColBias { - using ActivationFn = ActivationFn_; - static constexpr bool IsEltActSupported = true; -}; - -// D = activation(alpha * acc + beta * C + per-row bias) -// aux = alpha * acc + beta * C + per-row bias -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerRowBiasEltActAux - : LinCombPerRowBiasEltAct { - using ElementAux = ElementAux_; - using GmemLayoutTagAux = GmemLayoutTagAux_; - static constexpr int AlignmentAux = AlignmentAux_; - static constexpr bool IsAuxOutSupported = true; -}; - -// D = activation(alpha * acc + beta * C + per-col bias) -// aux = alpha * acc + beta * C + per-col bias -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerColBiasEltActAux - : LinCombPerColBiasEltAct { - using ElementAux = ElementAux_; - using GmemLayoutTagAux = GmemLayoutTagAux_; - static constexpr int AlignmentAux = AlignmentAux_; - static constexpr bool IsAuxOutSupported = true; -}; - -// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, // per-row alpha/beta - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - int AlignmentScalar_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct PerRowLinCombPerRowBiasEltAct - : LinCombPerRowBiasEltAct { - static constexpr int AlignmentScalar = AlignmentScalar_; - static constexpr bool IsPerRowScaleSupported = true; -}; - -// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, // per-row alpha/beta - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - int AlignmentScalar_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct PerColLinCombPerColBiasEltAct - : LinCombPerColBiasEltAct { - static constexpr int AlignmentScalar = AlignmentScalar_; - static constexpr bool IsPerColScaleSupported = true; -}; - -// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, // per-row alpha/beta - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - int AlignmentScalar_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct PerColResAddPerColBiasEltAct - : PerColLinCombPerColBiasEltAct { - static constexpr bool IsResidualSupported = true; -}; - -// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias -// if D is fp8 -// D = scale_d * activation(Z) -// else -// D = activation(Z) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct ScaledLinCombPerRowBiasEltAct - : LinCombPerRowBiasEltAct { - static constexpr bool IsScaleFactorSupported = true; -}; - -// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias -// if D is fp8 -// D = scale_d * activation(Z) -// else -// D = activation(Z) -template< - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct ScaledLinCombPerColBiasEltAct - : LinCombPerColBiasEltAct { - static constexpr bool IsScaleFactorSupported = true; -}; - -// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias -// if D is fp8 -// amax_d = max(abs(elements in activation(Z))) -// D = scale_d * activation(Z) -// else -// D = activation(Z) -// if Aux is fp8 -// amax_aux = max(abs(elements in Z)) -// Aux = scale_aux * Z -// else -// Aux = Z -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementAmax_ = ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct ScaledLinCombPerRowBiasEltActAmaxAux - : ScaledLinCombPerRowBiasEltAct { - using ElementAmax = ElementAmax_; - static constexpr bool IsAbsMaxSupported = true; - - using ElementAux = ElementAux_; - using GmemLayoutTagAux = GmemLayoutTagAux_; - static constexpr int AlignmentAux = AlignmentAux_; - static constexpr bool IsAuxOutSupported = true; -}; - -// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias -// if D is fp8 -// amax_d = max(abs(elements in activation(Z))) -// D = scale_d * activation(Z) -// else -// D = activation(Z) -// if Aux is fp8 -// amax_aux = max(abs(elements in Z)) -// Aux = scale_aux * Z -// else -// Aux = Z -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementAmax_ = ElementCompute_, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct ScaledLinCombPerColBiasEltActAmaxAux - : ScaledLinCombPerColBiasEltAct { - using ElementAmax = ElementAmax_; - static constexpr bool IsAbsMaxSupported = true; - - using ElementAux = ElementAux_; - using GmemLayoutTagAux = GmemLayoutTagAux_; - static constexpr int AlignmentAux = AlignmentAux_; - static constexpr bool IsAuxOutSupported = true; -}; - -// Z = Aux -// dY = alpha * acc + beta * C -// D = d_activation(dY, Z) -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombDeEltAct - : LinearCombination { - using ActivationFn = ActivationFn_; - static constexpr bool IsDeEltActSupported = true; - - using ElementAux = ElementAux_; - using GmemLayoutTagAux = GmemLayoutTagAux_; - static constexpr int AlignmentAux = AlignmentAux_; - static constexpr bool IsAuxInSupported = true; -}; - -// Z = Aux -// dY = alpha * acc + beta * C -// D = d_activation(dY, Z) -// dBias = sum of columns of D -template< - class GmemLayoutTagAux_, - template class ActivationFn_, - class ElementOutput_, - class ElementCompute_, - class ElementAux_ = ElementOutput_, - class ElementBias_ = ElementCompute_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentAux_ = 128 / cute::sizeof_bits_v, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombDeEltActDePerRowBias - : LinCombDeEltAct { - using ElementBias = ElementBias_; - static constexpr int AlignmentBias = AlignmentBias_; - static constexpr bool IsDePerRowBiasSupported = true; -}; - -template< - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombBlockScaleFactor - : LinearCombination { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - -// D = activation(alpha * acc + beta * C) -// With BlockScaleFactor generation (same recipe as LinCombBlockScaleFactor). -template< - template class ActivationFn_, - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombEltActBlockScaleFactor - : LinCombEltAct { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - -// D = alpha * acc + beta * C + per-row bias -// With BlockScaleFactor generation -template< - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerRowBiasBlockScaleFactor - : LinCombPerRowBias { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - - -// D = alpha * acc + beta * C + per-col bias -// With BlockScaleFactor generation. -template< - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerColBiasBlockScaleFactor - : LinCombPerColBias { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - - -// D = activation(alpha * acc + beta * C + per-row bias) -// With BlockScaleFactor generation. -template< - template class ActivationFn_, - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerRowBiasEltActBlockScaleFactor - : LinCombPerRowBiasEltAct { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - - -// D = activation(alpha * acc + beta * C + per-col bias) -// With BlockScaleFactor generation. -template< - template class ActivationFn_, - int SFVecSize_, - class ElementOutput_, - class ElementCompute_, - class ElementBlockScaleFactor_, - class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, - class ElementBias_ = ElementOutput_, - class ElementSource_ = ElementOutput_, - class ElementScalar_ = ElementCompute_, - int AlignmentBias_ = 128 / cute::sizeof_bits_v, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest -> -struct LinCombPerColBiasEltActBlockScaleFactor - : LinCombPerColBiasEltAct { - using ElementBlockScaleFactor = ElementBlockScaleFactor_; - static constexpr int SFVecSize = SFVecSize_; - static constexpr bool IsBlockScaleSupported = true; - using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp deleted file mode 100644 index dfbb75bf00bd2160af770566c4f3970a2c7b5b10..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp +++ /dev/null @@ -1,1322 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Fusion callbacks specializations for the sm100 TMA warp-specialized (ws) epilogue -*/ - - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" - -#include "cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Sm100 Tma warp specialized callbacks just alias to their sm90 counterpart -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... -> : FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... - > { - using FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>::FusionCallbacks; -}; - -// Sm100 direct store callbacks alias to sm100 tma callbacks with 0 stages -// Additional copy atom args will be ignored in the 0-stage specializations of aux load/store nodes -template < - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm100NoSmemWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... -> : FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized<0, 0, 0, false, false>, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... - > { - using FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized<0, 0, 0, false, false>, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>::FusionCallbacks; -}; - -// Sm100 Ptr array tma warp specialized callbacks just alias to their sm90 counterpart -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm100PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... -> : FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... - > { - using FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>::FusionCallbacks; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C -// With Row BlockScaleFactor Generation. -template< - int SFVecsize, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinearCombRowBlockScaleFactor = - Sm90EVT, // gen scalefactor - Sm90LinearCombination // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinearCombRowBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm100LinearCombRowBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombBlockScaleFactor; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { - { - // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = alpha * acc + beta * C -// With Col BlockScaleFactor Generation. -template< - int SFVecsize, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinearCombColBlockScaleFactor = - Sm90EVT, // gen scalefactor - Sm90LinearCombination // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinearCombColBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm100LinearCombColBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombBlockScaleFactor; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { - { - // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// For Ptr-Array and Grouped GEMM -// D = alpha * acc + beta * C, where alpha and beta can be vectors for each batch/group -// With Row BlockScaleFactor Generation, separate tensors per batch/group. -template< - int SFVecsize, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinearCombRowBlockScaleFactorPtrArray = - Sm90EVT, // gen scalefactor - Sm90LinearCombinationPtrArray // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100PtrArrayTmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinearCombRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm100LinearCombRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombBlockScaleFactor; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - // NormConst is a single device-side constant value, its not per-batch or per-group - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { - { - // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// For Ptr-Array and Grouped GEMM -// D = activation(alpha * acc + beta * C), where alpha and beta can be vectors for each batch/group -// With Row BlockScaleFactor Generation, separate tensors per batch/group. -template< - int SFVecsize, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombEltActRowBlockScaleFactorPtrArray = - Sm90EVT, // gen scalefactor - Sm90LinCombEltActPtrArray // activation(beta * C + (alpha * acc)) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100PtrArrayTmaWarpSpecialized, - fusion::LinCombEltActBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombEltActRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm100LinCombEltActRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombEltActBlockScaleFactor; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op: activation(beta * C + (alpha * acc)) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C + per-row bias -// with row blockScaled generation -template< - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerRowBiasRowBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerRowBiasRowBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, - ElementScalar, - AlignmentBias, - RoundStyle - > -{ - - using Impl = - Sm100LinCombPerRowBiasRowBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// D = alpha * acc + beta * C + per-row bias -// with col blockScaled generation -template< - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerRowBiasColBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorColStore< - SFVecsize, EpilogueTile, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerRowBiasColBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm100LinCombPerRowBiasColBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C + per_col bias -// with row blockScaled generation -template< - int StagesC, - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerColBiasRowBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerColBias< - StagesC, CtaTileShapeMNK, EpilogueTile, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerColBiasRowBlockScaleFactor< - StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm100LinCombPerColBiasRowBlockScaleFactor< - StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per-row bias) -// with row blockScaled generation -template< - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerRowBiasEltActRowBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, - ElementOutput, ElementCompute, - ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerRowBiasEltActRowBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm100LinCombPerRowBiasEltActRowBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per-row bias) -// with col blockScaled generation -template< - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerRowBiasEltActColBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorColStore< - SFVecsize, EpilogueTile, - ElementOutput, ElementCompute, - ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerRowBiasEltActColBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm100LinCombPerRowBiasEltActColBlockScaleFactor< - SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per_col bias) -// with row blockScaled generation -template< - int StagesC, - int SFVecsize, - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm100LinCombPerColBiasEltActRowBlockScaleFactor = - Sm90EVT< - Sm100BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, - ElementOutput, ElementCompute, - ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm100LinCombPerColBiasEltActRowBlockScaleFactor< - StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm100LinCombPerColBiasEltActRowBlockScaleFactor< - StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - - -// -------------------------------------------------------------------- -// Sm100PtrArrayNoSmemWarpSpecialized (direct-store, grouped GEMM) -// -------------------------------------------------------------------- -template < - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm100PtrArrayNoSmemWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...> - : FusionCallbacks< - // reuse the ptr-array *TMA* callbacks with 0 stages - epilogue::Sm100PtrArrayTmaWarpSpecialized<0,0,0,false,false>, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...> { - - using Base = FusionCallbacks< - epilogue::Sm100PtrArrayTmaWarpSpecialized<0,0,0,false,false>, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>; - - // bring ctors into scope - using Base::Base; -}; - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp deleted file mode 100644 index a20591288ad386543c3c7f0fd399c7fe45b7f60a..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp +++ /dev/null @@ -1,500 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree compute operations for the sm100 TMA warp-specialized (ws) epilogue -*/ - - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/sm100_blockscaled_layout.hpp" -#include "cutlass/epilogue/thread/activation.h" -#include "cute/tensor.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// BatchNormApply -// -// This node aims to do the batch norm apply. The procedure is described as follows: -// -// output = (input - mean) * inv_stddev * alpha + bias -// -// while: (1) input & output are 2 matrices with shape (M, N), -// which are frg_input & return value of the visit function -// -// (2) mean, inv_stddev, alpha & bias are 4 vectors with shape (N). -// which are loaded by ProducerLoadCallbacks -// -// To avoid redundant calculations in EVT, this node simplify the procedure as follows: -// -// output = input * alpha' + bias' -// -// while alpha' & bias' are 2 vectors with shape (N) calculated by mean, inv_stddev, alpha & bias -// -// The calculation among vectors is described as follows: -// -// alpha' = alpha * inv_stddev -// bias' = bias - mean * alpha' -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - // reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least - // this should just match CLC stage count - int Stages, - class CtaTileShapeMNK, - class ElementScalar, - class ElementCompute, - class ElementOutput, - class StrideMNL = Stride<_0,_1,_0>, - int Alignment = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -struct Sm100BatchNormApply { - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert(cute::is_same_v>); // row vector broadcast for alpha, bias, mean & inv_stddev - - using SmemLayout = decltype(make_layout(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})))); - - using ElementCol = cute::conditional_t<(sizeof(ElementCompute) > sizeof(ElementScalar)), ElementCompute, ElementScalar>; - - struct SharedStorage { - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_alpha; - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_bias; - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_mean; - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_inv_stddev; - }; - - struct Arguments { - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* bias_ptr = nullptr; - ElementScalar const* mean_ptr = nullptr; - ElementScalar const* inv_stddev_ptr = nullptr; - StrideMNL dVec = {}; - }; - - struct Params { - using TMA_Vec = decltype(make_tma_atom( - SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr(nullptr), repeat_like(StrideMNL{}, int32_t(0)), append<3>(StrideMNL{}, _0{})), - take<0,2>(SmemLayout{}), - take<0,2>(CtaTileShapeMNK{}))); - - TMA_Vec tma_load_alpha; - TMA_Vec tma_load_bias; - TMA_Vec tma_load_mean; - TMA_Vec tma_load_inv_stddev; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - - Tensor tensor_alpha = make_tensor(make_gmem_ptr(args.alpha_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); - Tensor tensor_bias = make_tensor(make_gmem_ptr(args.bias_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); - Tensor tensor_mean = make_tensor(make_gmem_ptr(args.mean_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); - Tensor tensor_inv_stddev = make_tensor(make_gmem_ptr(args.inv_stddev_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); - - typename Params::TMA_Vec tma_load_alpha = make_tma_atom(SM90_TMA_LOAD{}, tensor_alpha, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); - typename Params::TMA_Vec tma_load_bias = make_tma_atom(SM90_TMA_LOAD{}, tensor_bias, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); - typename Params::TMA_Vec tma_load_mean = make_tma_atom(SM90_TMA_LOAD{}, tensor_mean, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); - typename Params::TMA_Vec tma_load_inv_stddev = make_tma_atom(SM90_TMA_LOAD{}, tensor_inv_stddev, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); - - return Params{tma_load_alpha, tma_load_bias, tma_load_mean, tma_load_inv_stddev}; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm100BatchNormApply() { } - - CUTLASS_HOST_DEVICE - Sm100BatchNormApply(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms), - smem_alpha(const_cast(shared_storage.smem_alpha.data())), - smem_bias(const_cast(shared_storage.smem_bias.data())), - smem_mean(const_cast(shared_storage.smem_mean.data())), - smem_inv_stddev(const_cast(shared_storage.smem_inv_stddev.data())), - smem_col_alpha(const_cast(shared_storage.smem_alpha.data())), - smem_col_bias(const_cast(shared_storage.smem_bias.data())) { } - - Params const* params_ptr; - ElementScalar* smem_alpha; - ElementScalar* smem_bias; - ElementScalar* smem_mean; - ElementScalar* smem_inv_stddev; - ElementCompute* smem_col_alpha; - ElementCompute* smem_col_bias; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return true; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { - CUTLASS_DEVICE - ProducerLoadCallbacks(GTensor&& gAlpha, GTensor&& gBias, GTensor&& gMean, GTensor&& gInvStddev, - STensor&& sAlpha, STensor&& sBias, STensor&& sMean, STensor&& sInvStddev, Params const* params_ptr) - : gAlpha(cute::forward(gAlpha)), - gBias(cute::forward(gBias)), - gMean(cute::forward(gMean)), - gInvStddev(cute::forward(gInvStddev)), - sAlpha(cute::forward(sAlpha)), - sBias(cute::forward(sBias)), - sMean(cute::forward(sMean)), - sInvStddev(cute::forward(sInvStddev)), - params_ptr(params_ptr) {} - - GTensor gAlpha; - GTensor gBias; - GTensor gMean; - GTensor gInvStddev; - - STensor sAlpha; - STensor sBias; - STensor sMean; - STensor sInvStddev; - - Params const* params_ptr; - - CUTLASS_DEVICE void - step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { - if (epi_m == 0 && epi_n == 0 && issue_tma_load) { - // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size - constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * bits_to_bytes(sizeof_bits_v) * 4; - cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); - // Issue the TMA bulk copy - int pipe_index = (load_iteration / EpiTiles) % Stages; - copy(params_ptr->tma_load_alpha.with(*full_mbarrier_ptr), gAlpha, sAlpha(_,pipe_index)); - copy(params_ptr->tma_load_bias.with(*full_mbarrier_ptr), gBias, sBias(_,pipe_index)); - copy(params_ptr->tma_load_mean.with(*full_mbarrier_ptr), gMean, sMean(_,pipe_index)); - copy(params_ptr->tma_load_inv_stddev.with(*full_mbarrier_ptr), gInvStddev, sInvStddev(_,pipe_index)); - } - } - }; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - - Tensor mAlpha = params_ptr->tma_load_alpha.get_tma_tensor(make_shape(size(M),N,size(L))); - Tensor mBias = params_ptr->tma_load_bias.get_tma_tensor(make_shape(size(M),N,size(L))); - Tensor mMean = params_ptr->tma_load_mean.get_tma_tensor(make_shape(size(M),N,size(L))); - Tensor mInvStddev = params_ptr->tma_load_inv_stddev.get_tma_tensor(make_shape(size(M),N,size(L))); - - Tensor gAlpha = local_tile(mAlpha, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor gBias = local_tile(mBias, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor gMean = local_tile(mMean, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor gInvStddev = local_tile(mInvStddev, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - - Tensor sAlpha = make_tensor(make_smem_ptr(smem_alpha), SmemLayout{}); // (CTA_M,CTA_N,PIPE) - Tensor sBias = make_tensor(make_smem_ptr(smem_bias), SmemLayout{}); // (CTA_M,CTA_N,PIPE) - Tensor sMean = make_tensor(make_smem_ptr(smem_mean), SmemLayout{}); // (CTA_M,CTA_N,PIPE) - Tensor sInvStddev = make_tensor(make_smem_ptr(smem_inv_stddev), SmemLayout{}); // (CTA_M,CTA_N,PIPE) - - auto [tCgAlpha, tCsAlpha] = tma_partition(params_ptr->tma_load_alpha, group_modes<0,2>(sAlpha), group_modes<0,2>(gAlpha)); - auto [tCgBias, tCsBias] = tma_partition(params_ptr->tma_load_bias, group_modes<0,2>(sBias), group_modes<0,2>(gBias)); - auto [tCgMean, tCsMean] = tma_partition(params_ptr->tma_load_mean, group_modes<0,2>(sMean), group_modes<0,2>(gMean)); - auto [tCgInvStddev, tCsInvStddev] = tma_partition(params_ptr->tma_load_inv_stddev, group_modes<0,2>(sInvStddev), group_modes<0,2>(gInvStddev)); - - constexpr int EpiTiles = decltype(size(ceil_div(shape(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ProducerLoadCallbacks( - cute::move(tCgAlpha), cute::move(tCgBias), cute::move(tCgMean), cute::move(tCgInvStddev), - cute::move(tCsAlpha), cute::move(tCsBias), cute::move(tCsMean), cute::move(tCsInvStddev), params_ptr); - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - SR_RTensor&& tSR_rAlpha, SR_RTensor&& tSR_rBias, - SR_RTensor&& tSR_rMean, SR_RTensor&& tSR_rInvStddev, - SR_STensor&& tSR_sAlpha, SR_STensor&& tSR_sBias, - SR_STensor&& tSR_sMean, SR_STensor&& tSR_sInvStddev, - SR_CTensor&& tSR_cAlpha, - SR_SCTensor&& tSR_sColAlpha, SR_SCTensor&& tSR_sColBias, - RTensor&& tCrAlpha, RTensor&& tCrBias, - STensor&& tCsAlpha, STensor&& tCsBias, - ThrNum thr_num, - Params const* params_ptr) - : - tSR_rAlpha(cute::forward(tSR_rAlpha)), tSR_rBias(cute::forward(tSR_rBias)), - tSR_rMean(cute::forward(tSR_rMean)), tSR_rInvStddev(cute::forward(tSR_rInvStddev)), - tSR_sAlpha(cute::forward(tSR_sAlpha)), tSR_sBias(cute::forward(tSR_sBias)), - tSR_sMean(cute::forward(tSR_sMean)), tSR_sInvStddev(cute::forward(tSR_sInvStddev)), - tSR_cAlpha(cute::forward(tSR_cAlpha)), - tSR_sColAlpha(cute::forward(tSR_sColAlpha)), tSR_sColBias(cute::forward(tSR_sColBias)), - tCrAlpha(cute::forward(tCrAlpha)), tCrBias(cute::forward(tCrBias)), - tCsAlpha(cute::forward(tCsAlpha)), tCsBias(cute::forward(tCsBias)), - thr_num(thr_num), - params_ptr(params_ptr) {} - - SR_RTensor tSR_rAlpha; - SR_RTensor tSR_rBias; - SR_RTensor tSR_rMean; - SR_RTensor tSR_rInvStddev; - SR_STensor tSR_sAlpha; - SR_STensor tSR_sBias; - SR_STensor tSR_sMean; - SR_STensor tSR_sInvStddev; - SR_CTensor tSR_cAlpha; - SR_SCTensor tSR_sColAlpha; - SR_SCTensor tSR_sColBias; - - ThrNum thr_num; - - RTensor tCrAlpha; // (CPY,CPY_M,CPY_N) - RTensor tCrBias; // (CPY,CPY_M,CPY_N) - - STensor tCsAlpha; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - STensor tCsBias; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - - Params const* params_ptr; - - CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - if (epi_m == 0 && epi_n == 0) { // Assumes M-major subtile loop - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - int pipe_index = (load_iteration / EpiTiles) % Stages; - - Tensor tSR_rAlpha_flt = filter_zeros(tSR_rAlpha); - Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); - Tensor tSR_rMean_flt = filter_zeros(tSR_rMean); - Tensor tSR_rInvStddev_flt = filter_zeros(tSR_rInvStddev); - Tensor tSR_sAlpha_flt = filter_zeros(tSR_sAlpha(_,_,_,pipe_index)); - Tensor tSR_sBias_flt = filter_zeros(tSR_sBias(_,_,_,pipe_index)); - Tensor tSR_sMean_flt = filter_zeros(tSR_sMean(_,_,_,pipe_index)); - Tensor tSR_sInvStddev_flt = filter_zeros(tSR_sInvStddev(_,_,_,pipe_index)); - Tensor tSR_cAlpha_flt = filter_zeros(tSR_cAlpha, tSR_rAlpha.stride()); - - for (int i = 0; i < size(tSR_rAlpha_flt); ++i) { - if (get<1>(tSR_cAlpha_flt(i)) >= size<1>(CtaTileShapeMNK{})) { - // OOB of SMEM - continue; - } - tSR_rAlpha_flt(i) = tSR_sAlpha_flt(i); - tSR_rBias_flt(i) = tSR_sBias_flt(i); - tSR_rMean_flt(i) = tSR_sMean_flt(i); - tSR_rInvStddev_flt(i) = tSR_sInvStddev_flt(i); - } - - constexpr int RegFragSize = cute::min(size(tSR_rAlpha_flt), cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute)))); - Tensor tSR_rAlpha_frg = recast>(tSR_rAlpha_flt); // (FRG_V) - Tensor tSR_rBias_frg = recast>(tSR_rBias_flt); // (FRG_V) - Tensor tSR_rMean_frg = recast>(tSR_rMean_flt); // (FRG_V) - Tensor tSR_rInvStddev_frg = recast>(tSR_rInvStddev_flt); // (FRG_V) - - cutlass::multiplies> mul; - cutlass::negate> negate; - cutlass::multiply_add> mul_add; - - // We do computation among vectors before computation among matrices - // alpha' = alpha * inv_stddev - // bias' = bias - alpha' * mean - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tSR_rAlpha_frg); ++i) { - tSR_rAlpha_frg(i) = mul(tSR_rAlpha_frg(i), tSR_rInvStddev_frg(i)); - tSR_rBias_frg(i) = mul_add(tSR_rAlpha_frg(i), negate(tSR_rMean_frg(i)), tSR_rBias_frg(i)); - } - - Tensor tSR_sColAlpha_flt = filter_zeros(tSR_sColAlpha(_,_,_,pipe_index)); - Tensor tSR_sColBias_flt = filter_zeros(tSR_sColBias(_,_,_,pipe_index)); - // After computation, 4 vectors -> 2 vectors - for (int i = 0; i < size(tSR_rAlpha_flt); ++i) { - if (get<1>(tSR_cAlpha_flt(i)) >= size<1>(CtaTileShapeMNK{})) { - // OOB of SMEM - continue; - } - tSR_sColAlpha_flt(i) = tSR_rAlpha_flt(i); - tSR_sColBias_flt(i) = tSR_rBias_flt(i); - } - - synchronize(); - - // To do bn_apply with Acc, reload these 2 vectors with the consistent shape - copy_aligned(tCsAlpha(_,_,_,_,_,pipe_index), tCrAlpha); - copy_aligned(tCsBias(_,_,_,_,_,pipe_index), tCrBias); - } - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_inputs) { - constexpr int RegFragSize = cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute))); - cutlass::multiply_add> mul_add; - - Array frg_apply; - - using ConvertInput = NumericArrayConverter; - using ConvertOutput = NumericArrayConverter; - - ConvertInput convert_input{}; - ConvertOutput convert_output{}; - - Array frg_I = convert_input(frg_inputs); - - Tensor tCrAlpha_frg = recast>(tCrAlpha(_,_,_,epi_m,epi_n)); - Tensor tCrBias_frg = recast>(tCrBias(_,_,_,epi_m,epi_n)); - - constexpr int RegFragArraySize = FragmentSize / RegFragSize; - using RegFragArr = Array, RegFragArraySize>; - RegFragArr& frg_I_ = reinterpret_cast(frg_I); - RegFragArr& frg_apply_ = reinterpret_cast(frg_apply); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < RegFragArraySize; ++i) { - frg_apply_[i] = mul_add(tCrAlpha_frg(epi_v * RegFragArraySize + i), frg_I_[i], tCrBias_frg(epi_v * RegFragArraySize + i)); - } - - return convert_output(frg_apply); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - using ThreadCount = decltype(size(args.tiled_copy)); - - Tensor sAlpha = make_tensor(make_smem_ptr(smem_alpha), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor sBias = make_tensor(make_smem_ptr(smem_bias), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor sColAlpha = make_tensor(make_smem_ptr(smem_col_alpha), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor sColBias = make_tensor(make_smem_ptr(smem_col_bias), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor sMean = make_tensor(make_smem_ptr(smem_mean), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor sInvStddev = make_tensor(make_smem_ptr(smem_inv_stddev), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - - // S2R: Smem to Reg - auto tiled_s2r = make_tiled_copy(Copy_Atom{}, - Layout< Shape<_1, ThreadCount>, - Stride<_0, _1>>{}, - Layout<_1>{}); - auto thr_s2r = tiled_s2r.get_slice(args.thread_idx); - Tensor tSR_sAlpha = thr_s2r.partition_S(sAlpha); - Tensor tSR_sBias = thr_s2r.partition_S(sBias); - Tensor tSR_sMean = thr_s2r.partition_S(sMean); - Tensor tSR_sInvStddev = thr_s2r.partition_S(sInvStddev); - Tensor tSR_sColAlpha = thr_s2r.partition_S(sColAlpha); - Tensor tSR_sColBias = thr_s2r.partition_S(sColBias); - Tensor tSR_cAlpha = thr_s2r.partition_S(args.cD); - - Tensor tSR_rAlpha = make_tensor_like(take<0,3>(tSR_sAlpha)); // need to check - Tensor tSR_rBias = make_tensor_like(take<0,3>(tSR_sBias)); - Tensor tSR_rMean = make_tensor_like(take<0,3>(tSR_sMean)); - Tensor tSR_rInvStddev = make_tensor_like(take<0,3>(tSR_sInvStddev)); - - Tensor tCsAlpha = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sColAlpha, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCsBias = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sColBias, args.epi_tile, args.tiled_copy, args.thread_idx); - - Tensor tCrAlpha = make_tensor_like(take<0,5>(tCsAlpha)); // (CPY,CPY_M,CPY_N) - Tensor tCrBias = make_tensor_like(take<0,5>(tCsBias)); // (CPY,CPY_M,CPY_N) - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ConsumerStoreCallbacks( - cute::move(tSR_rAlpha), cute::move(tSR_rBias), - cute::move(tSR_rMean), cute::move(tSR_rInvStddev), - cute::move(tSR_sAlpha), cute::move(tSR_sBias), - cute::move(tSR_sMean), cute::move(tSR_sInvStddev), - cute::move(tSR_cAlpha), - cute::move(tSR_sColAlpha), cute::move(tSR_sColBias), - cute::move(tCrAlpha), cute::move(tCrBias), - cute::move(tCsAlpha), cute::move(tCsBias), - ThreadCount{}, - params_ptr); - } -}; - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp deleted file mode 100644 index d026b15ccacef0bb199b7a98172c722f9402d075..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp +++ /dev/null @@ -1,666 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree store operations for the sm100 TMA warp-specialized (ws) epilogue -*/ - - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/sm100_blockscaled_layout.hpp" -#include "cute/tensor.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" -#include "cutlass/detail/helper_macros.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -namespace detail { - template - CUTLASS_DEVICE auto - compute_quantized_with_row_scalefactor( - Array& frg_compute, - Array& frg_sf, - ElementCompute norm_constant) - { - cutlass::multiplies mul; - cutlass::multiplies> mul_array; - - Array frg_output; - auto output_frgs = reinterpret_cast *>(frg_output.data()); - auto compute_frgs = reinterpret_cast *>(frg_compute.data()); - - Array qpvscale_rcps = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (cute::is_same_v) { - // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate>{}(frg_sf); - return cutlass::NumericArrayConverter{}(e8m0_qpvscale_rcp); - } - else { - // UE4M3: Do the rcp in fp32 data type. - auto qpvscale_ups = cutlass::NumericArrayConverter{}(frg_sf); - return cutlass::reciprocal_approximate_ftz{}(qpvscale_ups); - } - }(); - - // norm_constant and qpvscale_rcps are all positive numbers. - auto acc_scales = cutlass::multiplies>{}(norm_constant, qpvscale_rcps); - - CUTLASS_PRAGMA_UNROLL - for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { - // Map INF to fp32::max - auto acc_scale = minimum_with_nan_propagation{}(acc_scales[sf_v], cutlass::platform::numeric_limits::max()); - // Convert to output type - output_frgs[sf_v] = cutlass::NumericArrayConverter{}(mul_array(compute_frgs[sf_v], acc_scale)); - } - return frg_output; - } -} -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// BlockScaleFactor Generation Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int SFVecSize, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -struct Sm100BlockScaleFactorRowStore { - static_assert(size<1>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); - static_assert(size<1>(EpilogueTile{}) / SFVecSize == 1 or - size<1>(EpilogueTile{}) / SFVecSize == 2 or - size<1>(EpilogueTile{}) / SFVecSize == 4 or - size<1>(EpilogueTile{}) / SFVecSize == 8, - "Possible store in interleaved 4B aligned format"); - using NormalConstStrideMNL = Stride<_0,_0,int64_t>; - struct SharedStorage { }; - - struct Arguments { - ElementBlockScaleFactor* ptr_scale_factor = nullptr; - ElementCompute const* norm_constant_ptr = nullptr; - NormalConstStrideMNL norm_constant_stride = {}; - }; - - using Params = Arguments; - - using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - bool implementable = (N % SFVecSize == 0); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm100BlockScaleFactorRowStore] N-dim should be divisible by SFVecSize.\n"); - } - return implementable; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm100BlockScaleFactorRowStore() { } - - CUTLASS_HOST_DEVICE - Sm100BlockScaleFactorRowStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { } - - Params const* params_ptr = nullptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template < - class RTensor, - class GTensor, - class CoordGTensor, - class ThrResidue, - class EpiTileCoordMN, - class ElementType - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rSFD_, // (CPY,CPY_M,CPY_N) - GTensor&& tC_gSFD_, // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - CoordGTensor tC_cSFD_, // (m,n) - ThrResidue residue_tC_cSFD_, // (m,n) - Params const* params_ptr_, - EpiTileCoordMN epi_tile_coord_mn_, // (epi_tile_coord_m, epi_tile_coord_n) - ElementType norm_constant_, - ElementType norm_constant_scaled_down_) - : tC_rSFD(cute::forward(tC_rSFD_)) - , tC_gSFD(cute::forward(tC_gSFD_)) - , tC_cSFD(tC_cSFD_) - , residue_tC_cSFD(residue_tC_cSFD_) - , params_ptr(params_ptr_) - , norm_constant(norm_constant_) - , norm_constant_scaled_down(norm_constant_scaled_down_) - , epi_tile_coord_mn(epi_tile_coord_mn_){} - - static_assert(is_same_v); - RTensor tC_rSFD; - GTensor tC_gSFD; - CoordGTensor tC_cSFD; - ThrResidue residue_tC_cSFD; - Params const* params_ptr; - ElementCompute norm_constant; - ElementCompute norm_constant_scaled_down; - EpiTileCoordMN epi_tile_coord_mn; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, - int epi_v, - int epi_m, - int epi_n, - Array const& frg_input) - { - static_assert(FragmentSize % SFVecSize == 0, "Scale factor vector size should divide FragmentSize"); - constexpr int NumVecs = FragmentSize / SFVecSize; - Array frg_compute; - - auto input_frgs = reinterpret_cast const*>(frg_input.data()); - auto compute_frgs = reinterpret_cast *>(frg_compute.data()); - - Tensor tC_rSFD_frg = recast>(coalesce(filter(tC_rSFD))); // (EPI_V) - - cutlass::multiplies mul; - cutlass::maximum_absolute_value_reduction, true> amax_reduction; - - cutlass::Array vec_maxs; - cutlass::Array pvscales; - // SF generation - CUTLASS_PRAGMA_UNROLL - for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { - compute_frgs[sf_v] = NumericArrayConverter{}(input_frgs[sf_v]); - /// Step1: get max across a vector - vec_maxs[sf_v] = amax_reduction(ElementCompute(0), compute_frgs[sf_v]); - } - - /// Step2: Compute Scale - pvscales = cutlass::multiplies>{}(vec_maxs, norm_constant_scaled_down); - - tC_rSFD_frg(_0{}) = cutlass::NumericArrayConverter{}(pvscales); - - Tensor tCgSFD_flt = filter_zeros(tC_gSFD(_,_,_,_0{},_0{},get<0>(epi_tile_coord_mn) + epi_m, get<1>(epi_tile_coord_mn) + epi_n)); - Tensor tCrSFD_flt = filter_zeros(tC_rSFD); - constexpr auto MCL = decltype(max_common_layout(tCgSFD_flt, tCrSFD_flt)){}; - constexpr int V = cute::min(4, size(MCL)); - using VecType = uint_bit_t>; - Tensor tCgSFD_vec = recast(coalesce(tCgSFD_flt)); - Tensor tCrSFD_vec = recast(coalesce(tCrSFD_flt)); - Tensor tCcSFD_pred = tC_cSFD(_,_,_, epi_m, epi_n); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrSFD_vec); i++){ - if (elem_less(tCcSFD_pred(i * SFVecSize * V), residue_tC_cSFD)) { - tCgSFD_vec(i) = tCrSFD_vec(i); - } - } - /// Step3: Compute quantized output values - return detail::compute_quantized_with_row_scalefactor(frg_compute, tC_rSFD_frg(_0{}), norm_constant); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [tile_coord_m, tile_coord_n, tile_coord_k, tile_coord_l] = args.tile_coord_mnkl; - using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; - UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; - // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group - if constexpr (!cute::is_same_v) { - ptr_scale_factor = params_ptr->ptr_scale_factor[tile_coord_l]; - tile_coord_l = 0; - } - else { - ptr_scale_factor = params_ptr->ptr_scale_factor; - } - - auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); - static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); - Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_,_,tile_coord_l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) - Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) - - auto epi_tile_coord_mn = make_coord(tile_coord_m * size<0>(epi_tile_mn), tile_coord_n * size<1>(epi_tile_mn)); - - // Fetch and compute these during initialization - Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); - ElementCompute norm_constant = mNormConst(_0{},_0{},tile_coord_l); - ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); - ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); - ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); -#if 0 - if(threadIdx.x == 128 && blockIdx.x == 0 && blockIdx.y == 0){ - print("epi_tile ");print(args.epi_tile); print("\n"); - print("mSFD ");print(mSFD); print("\n"); - print("gSFD ");print(gSFD); print("\n"); - print("tCgSFD ");print(tCgSFD); print("\n"); - print("tCrSFD ");print(tCrSFD); print("\n"); - print("filter(tCrSFD) ");print(filter(tCrSFD)); print("\n"); - print("filter(tCgSFD) ");print(filter(tCgSFD)); print("\n"); - } -#endif - - return ConsumerStoreCallbacks( - cute::move(tCrSFD), - cute::move(tCgSFD), - args.tCcD, - args.residue_tCcD, - params_ptr, - epi_tile_coord_mn, - norm_constant, - norm_constant_scaled_down); - - } -}; - -template < - int SFVecSize, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -struct Sm100BlockScaleFactorColStore { - - static_assert(size<0>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); - static_assert(size<0>(EpilogueTile{}) / SFVecSize == 1 or - size<0>(EpilogueTile{}) / SFVecSize == 2 or - size<0>(EpilogueTile{}) / SFVecSize == 4 or - size<0>(EpilogueTile{}) / SFVecSize == 8, - "Possible store in interleaved 4B aligned format"); - using NormalConstStrideMNL = Stride<_0,_0,int64_t>; - static constexpr int NumSyncWarps = SFVecSize == 64 ? 4 : 0; - static constexpr int NumSyncThreads = NumSyncWarps * NumThreadsPerWarp; - struct SharedStorage { - array_aligned smem_aux; - }; - - struct Arguments { - ElementBlockScaleFactor* ptr_scale_factor = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - ElementCompute const* norm_constant_ptr = nullptr; - NormalConstStrideMNL norm_constant_stride = {}; - }; - - using Params = Arguments; - - // BlockScaleFactor generation is per batch or group - // For Ptr-Array GEMM and Grouped GEMM, ElementBlockScaleFactor is ElementType* - using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - bool implementable = (M % SFVecSize == 0); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm100BlockScaleFactorColStore] M-dim should be divisible by SFVecSize.\n"); - } - return implementable; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm100BlockScaleFactorColStore() { } - - CUTLASS_HOST_DEVICE - Sm100BlockScaleFactorColStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) - , smem_aux(const_cast(shared_storage.smem_aux.data())) { } - - Params const* params_ptr = nullptr; - ElementCompute *smem_aux = nullptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template < - class RTensor, - class GTensor, - class STensor, - class CoordGTensor, - class ThrResidue, - class EpiTileCoordMN, - class ElementType - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - // Normally, we should use tile_shape_mnk to tile the gtensor. - // However, the SF gtensor could not be divisible by non-pow2 cta tile, so we use epi tile (pow2) to do tiling. - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rSFD_, // (CPY,CPY_M,CPY_N) - GTensor&& tC_gSFD_, // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - STensor&& sAmaxs_, // (NumSyncWarps) - CoordGTensor tC_cSFD_, // (m,n) - ThrResidue residue_tC_cSFD_, // (m,n) - Params const* params_ptr_, - EpiTileCoordMN epi_tile_coord_mn_, // (epi_tile_coord_m, epi_tile_coord_n) - ElementType norm_constant_, - ElementType norm_constant_scaled_down_) - : tC_rSFD(cute::forward(tC_rSFD_)) - , tC_gSFD(cute::forward(tC_gSFD_)) - , sAmaxs(cute::forward(sAmaxs_)) - , tC_cSFD(tC_cSFD_) - , residue_tC_cSFD(residue_tC_cSFD_) - , params_ptr(params_ptr_) - , norm_constant(norm_constant_) - , norm_constant_scaled_down(norm_constant_scaled_down_) - , epi_tile_coord_mn(epi_tile_coord_mn_) {} - - static_assert(is_same_v); - RTensor tC_rSFD; - GTensor tC_gSFD; - STensor sAmaxs; - CoordGTensor tC_cSFD; - ThrResidue residue_tC_cSFD; - Params const* params_ptr; - ElementCompute norm_constant; - ElementCompute norm_constant_scaled_down; - EpiTileCoordMN epi_tile_coord_mn; - - CUTLASS_DEVICE - ElementCompute find_amax(ElementCompute max) { - // Overall idea: after TMEM_LOAD.32DP32bit pattern, each thread in the warp can load adjacent elements of a column into its private RF. - // Here we are using shuffle instructons to the amax value of the adjacent column elements. - // For VS16, t0~t15 would generate an amax, and t16~t31 would generate another one. - // For VS32, t0~t31 should generate an amax. - // For VS64, t0~t63 should generate an amax. We would first do the reduciton within a warp, - // and then use smem to do inter-warp reduction. - if constexpr (SFVecSize == 32) { - return cutlass::redux_abs_max_nan_propagation_sync_warp{}(max); - } - else if constexpr (SFVecSize == 16) { - return cutlass::redux_abs_max_nan_propagation_sync_warp_t0t15_t16t31{}(max); - } - else if constexpr (SFVecSize == 64) { - // Get abs_max per warp - auto abs_max = cutlass::redux_abs_max_nan_propagation_sync_warp{}(max); - - // Switch the amax of adjacent warps - const bool leading_thread = (threadIdx.x % NumThreadsPerWarp) == 0; - const int warp_idx = threadIdx.x / NumThreadsPerWarp % 4; - auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(NumSyncThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - // Inter-warp reduction for VS=64 - // Only 4 * FP32 = 16 bytes smem is needed as we have 4 warps. - if (leading_thread) { - sAmaxs(warp_idx) = abs_max; - } - synchronize(); - // Switch data between two adjacent warps to do reduction - float tmp = sAmaxs(warp_idx^1); - synchronize(); - abs_max = cutlass::maximum_with_nan_propagation{}(abs_max,tmp); - return abs_max; - } - else { - static_assert(cutlass::detail::dependent_false, "Unsupported VecSize"); - } - } - - template - CUTLASS_DEVICE auto - compute_quantized_value(Array compute, Array sf) { - cutlass::multiplies> mul_array; - auto qpvscale_rcp = [&]() CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (cute::is_same_v) { - // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcps = cutlass::reciprocal_approximate>{}(sf); - return cutlass::NumericArrayConverter{}(e8m0_qpvscale_rcps); - } - else { - // UE4M3: Do the rcp in fp32 data type. - auto qpvscale_up = cutlass::NumericArrayConverter{}(sf); - return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); - } - }(); - // norm_constant and qpvscale_rcps[sf_v] are all positive numbers. - auto acc_scale = mul_array(norm_constant, qpvscale_rcp); - // Map INF to fp32::max - acc_scale = minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); - return mul_array(compute, acc_scale); - } - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, - int epi_v, - int epi_m, - int epi_n, - Array const& frg_input) - { - constexpr int NumVecs = 1; // each thread only compute 1 col scalefactors - Array frg_compute; - Array frg_output; - Array frg_scale_float; - Array frg_amax; - Array frg_scale; - - Tensor tC_rSFD_frg = recast>(coalesce(filter(tC_rSFD))); // (EPI_V) - - cutlass::multiplies mul; - cutlass::multiplies> mul_array; - /// convert acc to Element Compute - auto compute_frgs = NumericArrayConverter{}(frg_input); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - /// Step1: get max across a vector - frg_amax[i] = find_amax(compute_frgs[i]); - } - - frg_scale_float = mul_array(frg_amax, norm_constant_scaled_down); - frg_scale = cutlass::NumericArrayConverter{}(frg_scale_float); - auto tC_cSFD_pred = tC_cSFD(_,_,_,epi_m,epi_n); - auto tC_gSFD_store = tC_gSFD(_,_,_,_,_,get<0>(epi_tile_coord_mn) + epi_m, get<1>(epi_tile_coord_mn) + epi_n); - for (int i=0; i < cute::ceil_div(FragmentSize, SFVecSize); i++) { - int idx = i * SFVecSize + threadIdx.x % SFVecSize; - if (idx < FragmentSize && elem_less(tC_cSFD_pred(idx), residue_tC_cSFD)) { - UnderlyingElementBlockScaleFactor tmp = frg_scale[idx]; - // Store the (EpilogueTile / SFVecSize) elements. - tC_gSFD_store(idx) = tmp; - } - } - - /// Step3: Compute quantized output values - if constexpr (cute::sizeof_bits_v == 4) { - return compute_quantized_value(compute_frgs, frg_scale); // ElementCompute - } - else { - // 6bits or 8bits output. - compute_frgs = compute_quantized_value(compute_frgs, frg_scale); - frg_output = cutlass::NumericArrayConverter{}(compute_frgs); - return frg_output; // ElementOutput - } - - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [tile_coord_m, tile_coord_n, tile_coord_k, tile_coord_l] = args.tile_coord_mnkl; - using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig; - UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; - // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group - if constexpr (!cute::is_same_v) { - ptr_scale_factor = params_ptr->ptr_scale_factor[tile_coord_l]; - tile_coord_l = 0; - } - else { - ptr_scale_factor = params_ptr->ptr_scale_factor; - } - - auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); - //Tensor gSFD = local_tile(mSFD, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); - // Normally, we should use tile_shape_mnk to tile the mSFD tensor. However, we could not do it for non-pow2 cta tile with vectorsize = 32. - // For scale factor, 128x4 elements are stored in a basic block, and the layout of mSFD is ((_32,_4,int),(_32,_4,int),int):((_16,_4,int),(_0,_1, int),int) - // If we tiled it using tile_shape_mnk(128, 192), the N mode would encounter shape_div failure because (32, 4) could not be divisible by 192. - // Therefore, switching to using pow2 epilogue tile. - static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); - Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_,_,tile_coord_l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) - Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) - - auto epi_tile_coord_mn = make_coord(tile_coord_m * size<0>(epi_tile_mn), tile_coord_n * size<1>(epi_tile_mn)); - - // Fetch and compute these during initialization - Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); - ElementCompute norm_constant = mNormConst(_0{},_0{},tile_coord_l); - ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); - ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); - ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); - - Tensor sAmaxs = make_tensor(make_smem_ptr(smem_aux), make_layout(_4{})); -#if 0 - if(threadIdx.x == 128 && blockIdx.x == 0 && blockIdx.y == 0){ - print("mSFD ");print(mSFD); print("\n"); - print("gSFD ");print(gSFD); print("\n"); - print("tCgSFD ");print(tCgSFD); print("\n"); - print("tCrSFD ");print(tCrSFD); print("\n"); - print("args.tCcD ");print(args.tCcD); print("\n"); - print("args.residue_tCcD ");print(args.residue_tCcD); print("\n"); - print("filter(tCrSFD) ");print(filter(tCrSFD)); print("\n"); - print("filter(tCgSFD) ");print(filter(tCgSFD)); print("\n"); - } -#endif - - return ConsumerStoreCallbacks( - cute::move(tCrSFD), - cute::move(tCgSFD), - cute::move(sAmaxs), - args.tCcD, - args.residue_tCcD, - params_ptr, - epi_tile_coord_mn, - norm_constant, - norm_constant_scaled_down); - } -}; - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp deleted file mode 100644 index b769b1f0fbe2aa78f0ee97da442fb61c1aa49cc8..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp +++ /dev/null @@ -1,1593 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - -/*! \file - \brief Fusion callbacks specializations for the SM120 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Sm120 Tma warp specialized callbacks just alias to their sm90 counterpart -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... -> : FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... - > { - using FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>::FusionCallbacks; -}; - -// D = alpha * acc + beta * C -// With BlockScaleFactor Generation. -// 1. Find max of 32 F32 elements -// 2. Convert the max to UE8 (or UE4M3) and store the result. -// 3. Convert the UE8 (or UE4M3) back to F32 scale. -// 4. Reciprocal of F32 scale with MUFU. -// 5. Multiply each F32 element with the above reciprocal, then convert to ElementD -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinearCombRowBlockScaleFactor = - Sm90EVT, // gen scalefactor - Sm90LinearCombination // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinearCombRowBlockScaleFactor::type,ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm120LinearCombRowBlockScaleFactor::type,ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - - using Sm100Fusion = FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile - >; - using Operation = typename Sm100Fusion::Operation; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { - { - // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = alpha * acc + beta * C + per-row bias -// with row blockScaled generation -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerRowBiasRowBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, // gen scalefactor - Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerRowBiasRowBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm120LinCombPerRowBiasRowBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = activation(alpha * acc + beta * C + per-row bias) -// with row blockScaled generation -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerRowBiasEltActRowBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, // gen scalefactor - Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerRowBiasEltActRowBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm120LinCombPerRowBiasEltActRowBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = alpha * acc + beta * C + per_col bias -// with row blockScaled generation -template< - int StagesC, - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerColBiasRowBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, // gen scalefactor - Sm90LinCombPerColBias< - StagesC, CtaTileShapeMNK, EpilogueTile, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerColBiasRowBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm120LinCombPerColBiasRowBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = activation(alpha * acc + beta * C + per_col bias) -// with row blockScaled generation -template< - int StagesC, - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerColBiasEltActRowBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, // gen scalefactor - Sm90LinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerColBiasEltActRowBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm120LinCombPerColBiasEltActRowBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C -// with per column blockScaled generation -// 1. Find max of 32 F32 elements -// 2. Convert the max to UE8 (or UE4M3) and store the result. -// 3. Convert the UE8 (or UE4M3) back to F32 scale. -// 4. Reciprocal of F32 scale with MUFU. -// 5. Multiply each F32 element with the above reciprocal, then convert to ElementD -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinearCombColBlockScaleFactor = Sm90EVT< - Sm120BlockScaleFactorColStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle>, - Sm90LinearCombination< - ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle> - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized< - StagesC, StagesD, FragmentSize, ReuseSmemC, DelayTmaStore>, - fusion::LinCombBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute,ElementBlockScaleFactor, - cutlass::layout::ColumnMajor, ElementSource, ElementScalar, RoundStyle>, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinearCombColBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle - > { - - using Impl = Sm120LinearCombColBlockScaleFactor::type,ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; - - using Sm100Fusion = FusionCallbacks< - epilogue::Sm100TmaWarpSpecialized, - fusion::LinCombBlockScaleFactor, - CtaTileShapeMNK, - EpilogueTile - >; - using Operation = typename Sm100Fusion::Operation; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { - { - // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = alpha * acc + beta * C + per-Col bias -// with per column blockScaled generation -template< - int StagesC, - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerColBiasColBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorColStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerColBias< - StagesC, CtaTileShapeMNK, EpilogueTile, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerColBiasColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm120LinCombPerColBiasColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = activation(alpha * acc + beta * C + per_col bias) -// with per column blockScaled generation -template< - int StagesC, - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerColBiasEltActColBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorColStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerColBiasEltActColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - using Impl = - Sm120LinCombPerColBiasEltActColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerColBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, - ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// D = activation(alpha * acc + beta * C + per-row bias) -// with per column blockScaled generation -template< - int StagesC, - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerRowBiasEltActColBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorColStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, - Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, - ElementCompute, ElementCompute, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerRowBiasEltActColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > { - - - using Impl = - Sm120LinCombPerRowBiasEltActColBlockScaleFactor< - StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, - AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - - -// D = alpha * acc + beta * C + per-row bias -// with per column blockScaled generation -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombPerRowBiasColBlockScaleFactor = - Sm90EVT< - Sm120BlockScaleFactorColStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor, RoundStyle - >, // gen scalefactor - Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementCompute, ElementCompute, - ElementBias, ElementSource, ElementScalar, - AlignmentBias, RoundStyle - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120TmaWarpSpecialized, - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombPerRowBiasColBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - > -{ - - using Impl = - Sm120LinCombPerRowBiasColBlockScaleFactor< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementBias, - ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - using Operation = - fusion::LinCombPerRowBiasBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::ColumnMajor, - ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -// Sm120 Ptr array tma warp specialized callbacks just alias to their sm90 counterpart -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups, - class Operation, - class CtaTile_MNK, - class EpilogueTile_MN, - class... Args -> -struct FusionCallbacks< - epilogue::Sm120PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... -> : FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args... - > { - using FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - Operation, - CtaTile_MNK, - EpilogueTile_MN, - Args...>::FusionCallbacks; -}; - -// For Ptr-Array and Grouped GEMM -// D = alpha * acc + beta * C, where alpha and beta can be vectors for each batch/group -// With Row BlockScaleFactor Generation, separate tensors per batch/group. -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinearCombRowBlockScaleFactorPtrArray = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor *, RoundStyle - >, // gen scalefactor - Sm90LinearCombinationPtrArray< ElementCompute, ElementCompute, - ElementSource, ElementScalar, RoundStyle - > // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120PtrArrayTmaWarpSpecialized, - fusion::LinCombBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementSource, ElementScalar, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinearCombRowBlockScaleFactorPtrArray< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle - > { - - using Impl = - Sm120LinearCombRowBlockScaleFactorPtrArray< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle - >; - - using Operation = - fusion::LinCombBlockScaleFactor< - SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementSource, ElementScalar, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; - - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - - operator typename Impl::Arguments() const { - return - { - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - - -// For Ptr-Array and Grouped GEMM -// D = activation(alpha * acc + beta * C), where alpha and beta can be vectors for each batch/group -// With Row BlockScaleFactor Generation, separate tensors per batch/group. -template< - int SFVecsize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm120LinCombEltActRowBlockScaleFactorPtrArray = - Sm90EVT< - Sm120BlockScaleFactorRowStore< - SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, - ElementCompute, ElementBlockScaleFactor *, RoundStyle - >, // gen scalefactor - Sm90LinCombEltActPtrArray // activation(beta * C + (alpha * acc)) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - int SFVecSize, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm120PtrArrayTmaWarpSpecialized, - fusion::LinCombEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementSource, ElementScalar, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm120LinCombEltActRowBlockScaleFactorPtrArray< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle - > { - - using Impl = - Sm120LinCombEltActRowBlockScaleFactorPtrArray< - SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, - typename cutlass::detail::get_unpacked_element_type::type, - ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle - >; - - using Operation = - fusion::LinCombEltActBlockScaleFactor< - ActivationFn, SFVecSize, ElementOutput, ElementCompute, - ElementBlockScaleFactor, cutlass::layout::RowMajor, - ElementSource, ElementScalar, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; - - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - using StrideNormConst = Stride<_0,_0,int64_t>; - ElementCompute const* norm_constant_ptr = nullptr; - StrideNormConst dNormConst = {_0{}, _0{}, 0}; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp deleted file mode 100644 index e72e971bd8d99f87a2528af3c1dbd27366298ef5..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp +++ /dev/null @@ -1,899 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - -/*! \file - \brief Visitor tree store operations for the SM120 TMA warp-specialized (ws) epilogue -*/ - - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/detail/sm100_blockscaled_layout.hpp" -#include "cute/tensor.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// BlockScaleFactor Generation Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int SFVecSize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -struct Sm120BlockScaleFactorRowStore { - - static_assert(size<1>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); - static_assert(size<1>(EpilogueTile{}) / SFVecSize == 1 or - size<1>(EpilogueTile{}) / SFVecSize == 2 or - size<1>(EpilogueTile{}) / SFVecSize == 4 or - size<1>(EpilogueTile{}) / SFVecSize == 8, - "Possible store in interleaved 4B aligned format"); - - static constexpr int NumWarpgroups = 2; - static constexpr int NumSyncWarps = NumWarpsPerWarpGroup * NumWarpgroups; - static constexpr int NumQuadsPerWarp = 8; - static constexpr int NumSyncQuads = NumSyncWarps * NumQuadsPerWarp; - struct SharedStorage { - array_aligned smem_aux; - }; - using NormalConstStrideMNL = Stride<_0,_0,int64_t>; - struct Arguments { - ElementBlockScaleFactor* ptr_scale_factor = {}; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - ElementCompute const* norm_constant_ptr = {}; - NormalConstStrideMNL norm_constant_stride = {}; - }; - - using Params = Arguments; - - using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - bool implementable = (N % SFVecSize == 0); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm120BlockScaleFactorRowStore] N-dim should be divisible by SFVecSize.\n"); - } - return implementable; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm120BlockScaleFactorRowStore() { } - - CUTLASS_HOST_DEVICE - Sm120BlockScaleFactorRowStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) - , smem_aux(const_cast(shared_storage.smem_aux.data())) { } - - Params const* params_ptr = nullptr; - ElementCompute *smem_aux = nullptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template < - class RTensor, - class GTensor, - class STensor, - class CoordGTensor, - class ThrResidue, - class TileCoordMN, - class ElementType, - class TiledCopy_ - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rSFD_, - GTensor&& tC_gSFD_, - STensor&& sAmaxs_, - CoordGTensor tC_cSFD_, - ThrResidue residue_tC_cSFD_, - Params const* params_ptr_, - TileCoordMN tile_coord_mn_, - ElementType norm_constant_, - ElementType norm_constant_scaled_down_, - int thread_idx_, - TiledCopy_ const&) - : tC_rSFD(cute::forward(tC_rSFD_)) - , tC_gSFD(cute::forward(tC_gSFD_)) - , sAmaxs(cute::forward(sAmaxs_)) - , tC_cSFD(tC_cSFD_) - , residue_tC_cSFD(residue_tC_cSFD_) - , params_ptr(params_ptr_) - , norm_constant(norm_constant_) - , norm_constant_scaled_down(norm_constant_scaled_down_) - , tile_coord_mn(tile_coord_mn_) - , thread_idx(thread_idx_) {} - - static_assert(is_same_v); - RTensor tC_rSFD; - GTensor tC_gSFD; - STensor sAmaxs; - CoordGTensor tC_cSFD; - ThrResidue residue_tC_cSFD; - Params const* params_ptr; - ElementCompute norm_constant; - ElementCompute norm_constant_scaled_down; - TileCoordMN tile_coord_mn; - int thread_idx; - static constexpr int NumCollaboratingThreads = decltype(size(TiledCopy_{}))::value; - static_assert(NumCollaboratingThreads % NumThreadsPerWarpGroup == 0); - static constexpr int NumCollaboratingWarpGroups = NumCollaboratingThreads / NumThreadsPerWarpGroup; - static_assert(NumCollaboratingWarpGroups == 1 || NumCollaboratingWarpGroups == 2, - "SM120 epilogue currently only supports one or two warp groups collaborating."); - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, - int epi_v, - int epi_m, - int epi_n, - Array const& frg_input) { - return frg_input; - } - - template - CUTLASS_DEVICE void - reduce(SmemTensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - /* - Accumulator fragments are distributed across quads in different warps. - For SFVector = 16, we have: - - 8 elements 8 elements 8 elements 8 elements - <----------------><-----------------><-----------------><-----------------> - Warp 0 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 Warp 4 Quad 0 - Warp 0 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 Warp 4 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 Warp 4 Quad 7 - Warp 0 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 Warp 4 Quad 0 - Warp 0 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 Warp 4 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 Warp 4 Quad 7 - - - - - - In this case, row-wise scale factors are cooperatively reduced across 4 - threads from 1 quad in 1 warp. Each quad computes its own, local absolute - maximum without communicating with other warps through shared memory. - - For SFVector = 32, we have: - 8 elements 8 elements 8 elements 8 elements - <----------------><-----------------><-----------------><-----------------> - Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 - Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 - Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 - Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 - - - - - - For SFVector = 64, we have: - 8 elements 8 elements 8 elements 8 elements - <----------------><-----------------><-----------------><-----------------> - Warp 0 Quad 0 Warp 2 Quad 0 Warp 4 Quad 0 Warp 6 Quad 0 - Warp 0 Quad 1 Warp 2 Quad 1 Warp 4 Quad 1 Warp 6 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 2 Quad 7 Warp 4 Quad 7 Warp 6 Quad 7 - Warp 0 Quad 0 Warp 2 Quad 0 Warp 4 Quad 0 Warp 6 Quad 0 - Warp 0 Quad 1 Warp 2 Quad 1 Warp 4 Quad 1 Warp 6 Quad 1 - ... ... ... ... - Warp 0 Quad 7 Warp 2 Quad 7 Warp 4 Quad 7 Warp 6 Quad 7 - - - - Thus, rowwise scale factors are cooperatively reduced across 8 threads - from two quads in two warps. Each quad first computes its own, local - absolute maximum and then shares this with the corresponding quad in the - other warp. In this case, a reduction through shared memory is needed. - - For a non-cooperative epilogue (in which each warpgroup computes a - separate tile), the pattern is the same as that above, except that warps 0 - and 2 are in the same row, and 1 and 3 are in the same row, and warps 4-7 - are not included. - */ - - // Accumulator fragments consist of two elements from two different rows of a 16x8 MMA output - static constexpr int ColsPerThreadAccFrag = 2; - static constexpr int RowsPerThreadAccFrag = 2; - static_assert(FragmentSize == - (ColsPerThreadAccFrag * RowsPerThreadAccFrag)); - - static constexpr int NumThreadsPerQuad = 4; - static_assert(SFVecSize == 16 || SFVecSize == 32 || SFVecSize == 64, "SF vector size must be either 16, 32 or 64."); - // A quad from two or four warps participate in computing each scale factor. - constexpr int WarpsPerSF = SFVecSize / 16; - static_assert(WarpsPerSF == 1 || WarpsPerSF == 2 || WarpsPerSF == 4, "Only one, two or four warps are allowed in reduction."); - - constexpr bool IsInterWarpReductionNeeded = (WarpsPerSF != 1); - - // Number of fragments for each thread that are needed for computing a scale factor - static constexpr int AccFragsPerSF = SFVecSize / (ColsPerThreadAccFrag * NumThreadsPerQuad * WarpsPerSF); - static_assert(size<2>(visit_results) % AccFragsPerSF == 0, - "Fragments along N mode must be a multiple of the number of accumulator fragments needed per SF"); - - auto warp_idx = thread_idx / NumThreadsPerWarp; - auto warpgroup_idx = thread_idx / NumThreadsPerWarpGroup; - auto quad_idx_in_warp = (thread_idx % NumThreadsPerWarp) / NumThreadsPerQuad; - auto thread_idx_in_quad = thread_idx % NumThreadsPerQuad; - - cutlass::maximum_absolute_value_reduction amax_op; - cutlass::multiplies mul; - - Tensor tC_rSFD_flt = filter_zeros(tC_rSFD); - - auto synchronize = [&] () { - cutlass::arch::NamedBarrier::sync(NumCollaboratingThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - }; - - CUTLASS_PRAGMA_UNROLL - for (int sf_id = 0; sf_id < size(tC_rSFD_flt); ++sf_id) { - - auto coord = idx2crd(sf_id, tC_rSFD_flt.shape()); - auto row_in_acc = get<0,1,1>(coord); - auto row = crd2idx(get<1>(coord), get<1>(tC_rSFD_flt.shape())); - auto sf = crd2idx(get<2>(coord), get<2>(tC_rSFD_flt.shape())); - - // - // Compute amax for this scale factor - // - ElementCompute amax{0}; - - // Compute amax among vals owned by this thread for this vector - auto acc_frag_row = row_in_acc * RowsPerThreadAccFrag; - auto acc_frag_start_for_sf = sf * AccFragsPerSF; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < AccFragsPerSF; ++i) { - auto acc_frg = visit_results(0, row, acc_frag_start_for_sf + i); - amax = amax_op(amax, acc_frg[acc_frag_row]); - amax = amax_op(amax, acc_frg[acc_frag_row + 1]); - } - - // At this point, each thread has computed the amax of the values that it owns for this SF vector. - // We now need to compute the amax across threads. Because the TiledMMA uses an MmaThrLayout of <4,1,1>, - // we know that all fragments in this row will belong to threads in this warp. Furthermore, because - // SM120 narrow-precision MMAs have 16x8 output size with a quad owning two rows, we know that a quad - // will own all of the elements to be reduced via amax. Therefore, we can use warp shuffle intrinsics - // among threads in one quad to compute the amax. - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < 3; ++i) { - auto amax_other = __shfl_xor_sync(0xffffffff, amax, i); - amax = amax_op(amax, amax_other); - } - - if constexpr (IsInterWarpReductionNeeded) { - // At this point, all threads in the quad have the amax for the elements of the accumulator owned by its quad - // that should be used in computing the amax for this SF. Threads 0 in each quad of warps 0 and 2 - // (similarly, 1 and 3) now exchange amaxes to compute the final amax. - if (thread_idx_in_quad == 0) { - sAmaxs(quad_idx_in_warp, warp_idx) = amax; - } - synchronize(); - - // Get the amax broadcasted by the warp with which we share. - // Work on 4 warps per SFD generation - if constexpr (WarpsPerSF == 4) { - if constexpr (NumCollaboratingWarpGroups == 2) { - // This implementation assumes warp layout 2 x 4. - // For cooperative kernels (NumCollaboratingWarpGroups=2), - // warp 0 shares with 2 / 4 / 6, warp 1 shares with 3 / 5/ 7. - auto amax_other2 = sAmaxs(quad_idx_in_warp, warp_idx ^ 2); - auto amax_other4 = sAmaxs(quad_idx_in_warp, warp_idx ^ 4); - auto amax_other6 = sAmaxs(quad_idx_in_warp, warp_idx ^ 6); - synchronize(); - amax = amax_op(amax, amax_other2); - amax = amax_op(amax, amax_other4); - amax = amax_op(amax, amax_other6); - } - else { - static_assert(cutlass::detail::dependent_false, "Unsupported warp layout."); - } - } - // Work on 2 warps per SFD generation - else if constexpr(WarpsPerSF == 2) { - // For cooperative kernels (NumCollaboratingWarpGroups=2), 0 shares - // with 4, 1 shares with 5, etc. For non-cooperative kernels - // (NumCollaboratingWarpGroups=1), 0 shares with 2, 1 shares with 3. - auto amax_other = sAmaxs( - quad_idx_in_warp, warp_idx ^ (1 << NumCollaboratingWarpGroups)); - synchronize(); - amax = amax_op(amax, amax_other); - } - } - - ElementCompute pvscale = mul(amax, norm_constant_scaled_down); - UnderlyingElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); - tC_rSFD_flt(coord) = qpvscale; - - // - // Apply the scale factor to the output - // - ElementCompute qpvscale_rcp = [&]() { - if constexpr (cute::is_same_v) { - // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); - return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); - } - else { - // UE4M3: Do the rcp in fp32 data type. - auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); - return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); - } - }(); - - ElementCompute acc_scale = mul(norm_constant, qpvscale_rcp); - acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); - - // Compute quantized output values - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < AccFragsPerSF; ++i) { - auto acc_frag = visit_results(0, row, acc_frag_start_for_sf + i); - visit_results(0, row, acc_frag_start_for_sf + i)[acc_frag_row ] = mul(acc_frag[acc_frag_row], acc_scale); - visit_results(0, row, acc_frag_start_for_sf + i)[acc_frag_row + 1] = mul(acc_frag[acc_frag_row + 1], acc_scale); - } - } // sf - - // Since scale factors are computed cooperatively across two quads from two warps, we only need one thread from the - // set of 8 cooperating threads to write out the data. We do this with thread 0 in each quad of the first warp that collaborates. - bool write_sf = (thread_idx_in_quad == 0); - if constexpr (NumCollaboratingWarpGroups == 2) { - // For cooperative kernels (NumCollaboratingWarpGroups=2), 0 shares with 4, 1 shares with 5, etc. - // Thus, only the warps in the first warpgroup need to write out scale factors. - if constexpr (IsInterWarpReductionNeeded) { - write_sf &= warp_idx < NumWarpsPerWarpGroup; - } - } - else { - if constexpr (IsInterWarpReductionNeeded) { - // When non-cooperative kernels apply inter warp reduce, they are with - // SF output rule as below : - // 1. warp 0 shares with 2 and 1 shares with 3 within each warpgroup. - // 2. warps 0 and 1 of the first warpgroup and 4 and 5 of the second - // warpgroup need to write output sf. - write_sf &= ((warp_idx < 2) || (warpgroup_idx == 1 && warp_idx < 6)); - } - } - - if (write_sf && elem_less(tC_cSFD(_0{}, _0{}, _0{}, epi_m, epi_n), residue_tC_cSFD)) { - copy_aligned(tC_rSFD, tC_gSFD(_, _, _, _0{}, _0{}, get<0>(tile_coord_mn) + epi_m, get<1>(tile_coord_mn) + epi_n)); - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig; - UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; - // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group - if constexpr (!cute::is_same_v) { - ptr_scale_factor = params_ptr->ptr_scale_factor[l]; - l = 0; - } - else { - ptr_scale_factor = params_ptr->ptr_scale_factor; - } - - auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); - - static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); - Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_, _,l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) - Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) - - auto tile_coord_mn = make_coord(m * size<0>(epi_tile_mn), n * size<1>(epi_tile_mn)); - - // Fetch and compute these during initialization - Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); - ElementCompute norm_constant = mNormConst(_0{},_0{},l); - ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); - ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); - ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); - - Tensor sAmaxs = make_tensor( - make_smem_ptr(smem_aux), - make_layout(make_shape(Int{}, Int{})) - ); - - return ConsumerStoreCallbacks( - cute::move(tCrSFD), - cute::move(tCgSFD), - cute::move(sAmaxs), - args.tCcD, - args.residue_tCcD, - params_ptr, - tile_coord_mn, - norm_constant, - norm_constant_scaled_down, - args.thread_idx, - args.tiled_copy); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int SFVecSize, - class EpilogueTile, - class CtaTileShapeMNK, - int FragmentSize, - class ElementOutput, - class ElementCompute, - class ElementBlockScaleFactor, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -struct Sm120BlockScaleFactorColStore { - - static_assert(size<0>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); - static_assert(size<0>(EpilogueTile{}) / SFVecSize == 1 or - size<0>(EpilogueTile{}) / SFVecSize == 2 or - size<0>(EpilogueTile{}) / SFVecSize == 4, - "Possible store in interleaved 4B aligned format"); - - static constexpr int NumWarpgroups = 2; - static constexpr int NumSyncWarps = NumWarpsPerWarpGroup * NumWarpgroups; - static constexpr int NumThreadsPerQuad = 4; - static constexpr int NumSyncElementsCrossWarp = NumSyncWarps * NumThreadsPerQuad; - struct SharedStorage { - array_aligned smem_aux; - }; - - using NormalConstStrideMNL = Stride<_0,_0,int64_t>; - - struct Arguments { - ElementBlockScaleFactor* ptr_scale_factor = {}; - // A matrix wide constant value to scale the output matrix - // Avoids generating small FP4 values. - ElementCompute const* norm_constant_ptr = {}; - NormalConstStrideMNL norm_constant_stride = {}; - }; - using Params = Arguments; - - using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - bool implementable = (M % SFVecSize == 0); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm120BlockScaleFactorColStore] N-dim should be divisible by SFVecSize.\n"); - } - return implementable; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm120BlockScaleFactorColStore() { } - - CUTLASS_HOST_DEVICE - Sm120BlockScaleFactorColStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) - , smem_aux(const_cast(shared_storage.smem_aux.data())) { } - - Params const* params_ptr = nullptr; - ElementCompute *smem_aux = nullptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template < - class RTensor, - class GTensor, - class STensor, - class CoordGTensor, - class ThrResidue, - class TileCoordMN, - class ElementType, - class TiledCopy_ - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rSFD_, - GTensor&& tC_gSFD_, - STensor&& sAmaxs_, - CoordGTensor tC_cSFD_, - ThrResidue residue_tC_cSFD_, - Params const* params_ptr_, - TileCoordMN tile_coord_mn_, - ElementType norm_constant_, - ElementType norm_constant_scaled_down_, - int thread_idx_, - TiledCopy_ const&) - : tC_rSFD(cute::forward(tC_rSFD_)) - , tC_gSFD(cute::forward(tC_gSFD_)) - , sAmaxs(cute::forward(sAmaxs_)) - , tC_cSFD(tC_cSFD_) - , residue_tC_cSFD(residue_tC_cSFD_) - , params_ptr(params_ptr_) - , norm_constant(norm_constant_) - , norm_constant_scaled_down(norm_constant_scaled_down_) - , tile_coord_mn(tile_coord_mn_) - , thread_idx(thread_idx_) {} - - static_assert(is_same_v); - RTensor tC_rSFD; - GTensor tC_gSFD; - STensor sAmaxs; - CoordGTensor tC_cSFD; - ThrResidue residue_tC_cSFD; - Params const* params_ptr; - ElementCompute norm_constant; - ElementCompute norm_constant_scaled_down; - TileCoordMN tile_coord_mn; - int thread_idx; - static constexpr int NumCollaboratingThreads = decltype(size(TiledCopy_{}))::value; - static_assert(NumCollaboratingThreads % NumThreadsPerWarpGroup == 0); - static constexpr int NumCollaboratingWarpGroups = NumCollaboratingThreads / NumThreadsPerWarpGroup; - static_assert(NumCollaboratingWarpGroups == 2, - "SM120 epilogue currently only supports two warp groups collaborating."); - static_assert(SFVecSize == 16 || SFVecSize == 32 || SFVecSize == 64, "SF vector size must be either 16, 32 or 64."); - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, - int epi_v, - int epi_m, - int epi_n, - Array const& frg_input) { - return frg_input; - } - - template - CUTLASS_DEVICE void - reduce(SmemTensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - /* - Accumulator fragments are distributed across threads/quads in different warps. For column major, the - reduction happens along M dimension. For SFVector = 32, we have: - - 8 elements 8 elements 8 elements 8 elements - + <----------------------><----------------------><----------------------><----------------------> - | Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 - | Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 - | ... ... ... ... - 1 | Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 - 6 | Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 - | Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 - | ... ... ... ... - + Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 - | Warp 1 Quad 0 Warp 5 Quad 0 Warp 1 Quad 0 Warp 5 Quad 0 - | Warp 1 Quad 1 Warp 5 Quad 1 Warp 1 Quad 1 Warp 5 Quad 1 - 1 | ... ... ... ... - 6 | Warp 1 Quad 7 Warp 5 Quad 7 Warp 1 Quad 7 Warp 5 Quad 7 - | Warp 1 Quad 0 Warp 5 Quad 0 Warp 1 Quad 0 Warp 5 Quad 0 - | Warp 1 Quad 1 Warp 5 Quad 1 Warp 1 Quad 1 Warp 5 Quad 1 - | ... ... ... ... - | Warp 1 Quad 7 Warp 5 Quad 7 Warp 1 Quad 7 Warp 5 Quad 7 - - - - In this case, colum-wise scale factors are cooperatively reduced across 8 threads from 2 warps. - Each column first computes its own, local absolute maximum and then shares this with the - corresponding threads in the other warp. In this case, a reduction through shared memory is needed. - - For SFVector = 64, the reduction happens inside 4 warps: warp 0/1/2/3 and warp 4/5/6/7. - */ - - // Accumulator fragments consist of two elements from two different columns of a 16x8 MMA output - static constexpr int RowsPerThreadAccFrag = 2; - static constexpr int ColsPerThreadAccFrag = 2; - static_assert(FragmentSize == (ColsPerThreadAccFrag * RowsPerThreadAccFrag)); - - static constexpr int NumThreadsPerCol = NumThreadsPerWarp / NumThreadsPerQuad; - constexpr int WarpsPerSF = SFVecSize / NumThreadsPerCol / ColsPerThreadAccFrag; - static_assert(WarpsPerSF == 1 || WarpsPerSF == 2 || WarpsPerSF == 4, "Only one, two or four warps are allowed in reduction."); - - auto warp_idx = thread_idx / NumThreadsPerWarp; - auto thread_idx_in_warp = thread_idx % NumThreadsPerWarp; - - cutlass::maximum_absolute_value_reduction amax_op; - cutlass::multiplies mul; - - auto synchronize = [&] () { - // When WarpsPerSF equals 1, data processing is inside warp, there is no needs to have the sync. - static constexpr bool NoSyncNeeded = (WarpsPerSF == 1); - if(NoSyncNeeded) - return; - cutlass::arch::NamedBarrier::sync(NumCollaboratingThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - }; - - CUTLASS_PRAGMA_UNROLL - for(int mma_in_epi = 0; mma_in_epi < size<1>(tC_rSFD)*size<2>(tC_rSFD); ++mma_in_epi) { - - CUTLASS_PRAGMA_UNROLL - for (int sf_id = 0; sf_id < ColsPerThreadAccFrag; ++sf_id) { - - // - // Compute amax for this scale factor - // - ElementCompute amax{0}; - - // Compute amax among vals owned by this thread for this vector - auto acc_frg = visit_results(mma_in_epi); - amax = amax_op(amax, acc_frg[sf_id]); - amax = amax_op(amax, acc_frg[sf_id + ColsPerThreadAccFrag]); - - // At this point, each thread has computed the amax of the values that it owns for this SF vector. - // We now need to compute the amax across threads. Because SM120 narrow-precision MMAs have 16x8 output - // size with a quad owning two rows, we know that 8 threads in one column will own all of the 16 elements - // to be reduced via amax. Therefore, we can use warp shuffle intrinsics among threads to compute the amax. - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < NumThreadsPerCol; ++i) { - auto amax_other = __shfl_xor_sync(0xffffffff, amax, (i * NumThreadsPerQuad)); - amax = amax_op(amax, amax_other); - } - - // At this point, all threads in the quad have the amax for the elements of the accumulator owned by its - // threads that should be used in computing the amax for this SF. - if (thread_idx_in_warp < NumThreadsPerQuad && WarpsPerSF != 1) { - sAmaxs(thread_idx_in_warp, warp_idx) = amax; - } - - synchronize(); - - // Get the amax broadcasted by the warp with which we share. - // For cooperative kernels, when scale factor vector size is 32 (WarpsPerSF equals 2), - // warp 0 shares with 1, warp2 shares with 2, etc. - // When vector size is 64 (WarpsPerSF equals 4), warp 0 shares with 1/2/3, and 4 shares with 5/6/7. - // When vector size is 16, no needs to swap between warps. - if constexpr (2 == WarpsPerSF) { - auto amax_other = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 1); - amax = amax_op(amax, amax_other); - } - else if constexpr (4 == WarpsPerSF) { - auto amax_other1 = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 1); - auto amax_other2 = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 2); - auto amax_other3 = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 3); - amax = amax_op(amax, amax_other1); - amax_other2 = amax_op(amax_other2, amax_other3); - amax = amax_op(amax, amax_other2); - } - synchronize(); - - ElementCompute pvscale = mul(amax, norm_constant_scaled_down); - UnderlyingElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); - filter(tC_rSFD)(sf_id + mma_in_epi*ColsPerThreadAccFrag) = qpvscale; - - // - // Apply the scale factor to the output - // - ElementCompute qpvscale_rcp = [&]() { - if constexpr (cute::is_same_v) { - // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. - auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); - return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); - } - else { - // UE4M3: Do the rcp in fp32 data type. - auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); - return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); - } - }(); - - ElementCompute acc_scale = mul(norm_constant, qpvscale_rcp); - acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); - - // Compute quantized output values - visit_results(mma_in_epi)[sf_id ] = mul(acc_frg[sf_id ], acc_scale); - visit_results(mma_in_epi)[sf_id + ColsPerThreadAccFrag] = mul(acc_frg[sf_id + ColsPerThreadAccFrag], acc_scale); - } // end for sf_id - } // end for mma_in_epi - - // Since scale factors are computed cooperatively across two or four warps, we only need one thread from the - // cooperating column threads group to write out the data. - bool write_sf = (thread_idx_in_warp < NumThreadsPerQuad); - if constexpr (2 == WarpsPerSF) { - // Output warp {0, 2, 4, 6}. - write_sf &= ((warp_idx & 0x1) == 0); - } - else if constexpr (4 == WarpsPerSF) { - // Output warp {0, 4}. - write_sf &= ((warp_idx & 0x3) == 0); - } - else if constexpr (1 == WarpsPerSF) { - // Output warp {0, 1, ..., 7}. Keep write_sf as is. - } - - if (write_sf && elem_less(tC_cSFD(_0{}, _0{}, _0{}, epi_m, epi_n), residue_tC_cSFD)) { - copy_aligned(tC_rSFD, tC_gSFD(_, _, _, _0{}, _0{}, get<0>(tile_coord_mn) + epi_m, get<1>(tile_coord_mn) + epi_n)); - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; - UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; - // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group - if constexpr (!cute::is_same_v) { - ptr_scale_factor = params_ptr->ptr_scale_factor[l]; - l = 0; - } - else { - ptr_scale_factor = params_ptr->ptr_scale_factor; - } - - static_assert(size<0>(EpilogueTile{}) && ((size<0>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), - "Epilogue Tile N should be pow of 2"); - - auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); - Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), - Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); - - Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_, _,l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) - Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) - gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) - - auto tile_coord_mn = make_coord(m * size<0>(epi_tile_mn), n * size<1>(epi_tile_mn)); - - // Fetch and compute these during initialization - Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); - ElementCompute norm_constant = mNormConst(_0{},_0{},l); - ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); - ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); - ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); - - Tensor sAmaxs = make_tensor( - make_smem_ptr(smem_aux), - make_layout(make_shape(Int{}, Int{})) - ); - - return ConsumerStoreCallbacks( - cute::move(tCrSFD), - cute::move(tCgSFD), - cute::move(sAmaxs), - args.tCcD, - args.residue_tCcD, - params_ptr, - tile_coord_mn, - norm_constant, - norm_constant_scaled_down, - args.thread_idx, - args.tiled_copy); - - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp deleted file mode 100644 index 95e8208686ead6606040ee280023a7f5b879b07b..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ /dev/null @@ -1,2792 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Fusion callbacks specializations for the sm90 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/fusion/callbacks.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" - -#include "cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -using Sm90EVT = Sm90TreeVisitor; - -// D = alpha * acc -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledAcc, - CtaTileShapeMNK, - EpilogueTile -> : Sm90EVT, - Sm90ScalarBroadcast>, - Sm90AccFetch - > { - using Impl = - Sm90EVT, - Sm90ScalarBroadcast>, - Sm90AccFetch - >; - using Operation = fusion::ScaledAcc; - - struct Arguments { - // Give a name and flat ordering to the fusion callback args - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - - // Conversion to the args expected by the visitor implementation - // to_underlying_arguments will implicitly call this - operator typename Impl::Arguments() const { - return - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }; // end binary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C -template< - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinearCombination = - Sm90EVT, // beta * C + (alpha * acc) - Sm90ScalarBroadcast>, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc - Sm90ScalarBroadcast>, // alpha - Sm90AccFetch // acc - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinearCombination, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinearCombination; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C, where beta and alpha can be vectors for each batch -template< - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinearCombinationPtrArray = - Sm90EVT, // beta * C + (alpha * acc) - Sm90ScalarBroadcastPtrArray>, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc - Sm90ScalarBroadcastPtrArray>, // alpha - Sm90AccFetch // acc - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - fusion::LinearCombination, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinearCombination; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C) -template< - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombEltAct = - Sm90EVT, // activation(beta * C + (alpha * acc)) - Sm90LinearCombination // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombEltAct, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombEltAct { - - using Impl = Sm90LinCombEltAct::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombEltAct; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op: activation(beta * C + (alpha * acc)) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args: activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C), where beta and alpha can be vectors for each batch -template< - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombEltActPtrArray = - Sm90EVT, // activation(beta * C + (alpha * acc)) - Sm90LinearCombinationPtrArray // beta * C + (alpha * acc) - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - int NumEpilogueWarpGroups, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90PtrArrayTmaWarpSpecialized, - fusion::LinCombEltAct, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombEltActPtrArray { - - using Impl = Sm90LinCombEltActPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombEltAct; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - ElementScalar const* const* alpha_ptr_array = nullptr; - ElementScalar const* const* beta_ptr_array = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op: activation(beta * C + (alpha * acc)) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args: activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C + per-row bias -template< - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerRowBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast>, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast>, // alpha - Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerRowBias, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { - using Impl = Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; - using Operation = fusion::LinCombPerRowBias< - ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = alpha * acc + beta * C + per-column bias -template< - int StagesC, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerColBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast>, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast>, // alpha - Sm90AccFetch, // acc - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerColBias, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombPerColBias< - StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { - using Impl = Sm90LinCombPerColBias< - StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; - using Operation = fusion::LinCombPerColBias< - ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per-row bias) -template< - class CtaTileShapeMNK, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerRowBiasEltAct = - Sm90EVT, - Sm90LinCombPerRowBias - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - using Operation = - fusion::LinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per-column bias) -template< - int StagesC, - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerColBiasEltAct = - Sm90EVT, - Sm90LinCombPerColBias - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90LinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - using Operation = - fusion::LinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(alpha * acc + beta * C + per-row bias) -// Aux = alpha * acc + beta * C + per-row bias) -template< - class CtaTileShapeMNK, - class EpilogueTile, - int Stages, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerRowBiasEltActAux = - Sm90EVT, - Sm90EVT, - Sm90LinCombPerRowBias - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentAux, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpR2S -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerRowBiasEltActAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpR2S -> : Sm90LinCombPerRowBiasEltActAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90LinCombPerRowBiasEltActAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - using Operation = - fusion::LinCombPerRowBiasEltActAux< - GmemLayoutTagAux, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux* aux_ptr = nullptr; - StrideAux dAux = {}; - - operator typename Impl::Arguments() const { - return - { // unary op : activation(store(beta * C + (alpha * acc + bias))) - { // unary op : store(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, dAux} // unary args : store - }, // end unary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// D = activation(alpha * acc + beta * C + per_col bias) -// Aux = alpha * acc + beta * C + per_col bias) -template< - int StagesC, - class CtaTileShapeMNK, - class EpilogueTile, - int Stages, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombPerColBiasEltActAux = - Sm90EVT, - Sm90EVT, - Sm90LinCombPerColBias - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentAux, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpR2S -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerColBiasEltActAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpR2S -> : Sm90LinCombPerColBiasEltActAux< - StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90LinCombPerColBiasEltActAux< - StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - using Operation = - fusion::LinCombPerColBiasEltActAux< - GmemLayoutTagAux, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux* aux_ptr = nullptr; - StrideAux dAux = {}; - - operator typename Impl::Arguments() const { - return - { // unary op : activation(store(beta * C + (alpha * acc + bias))) - { // unary op : store(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, dAux} // unary args : store - }, // end unary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = per-row alpha * acc + per-row beta * C + per-row bias -template< - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - int AlignmentScalar = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90PerRowLinCombPerRowBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // beta, dynamic scalar/vector broadcast - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast - Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias - > - >; - -// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) -template< - class CtaTileShapeMNK, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - int AlignmentScalar = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90PerRowLinCombPerRowBiasEltAct = - Sm90EVT, - Sm90PerRowLinCombPerRowBias - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - int AlignmentScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::PerRowLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90PerRowLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - > { - - using Impl = - Sm90PerRowLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - using Operation = - fusion::PerRowLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - - struct Arguments { - using StrideAlpha = Stride; - using StrideBeta = Stride; - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - StrideAlpha dAlpha = {bool(1), _0{}, 0}; - StrideBeta dBeta = {bool(1), _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {beta_ptr, beta, dBeta}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {alpha_ptr, alpha, dAlpha}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = per-col alpha * acc + per-col beta * C + per-column bias -template< - int StagesC, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - int AlignmentScalar = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90PerColLinCombPerColBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector broadcast - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast - Sm90AccFetch, // acc - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias - > - >; - -// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) -template< - int StagesC, - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - int AlignmentScalar = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90PerColLinCombPerColBiasEltAct = - Sm90EVT, - Sm90PerColLinCombPerColBias - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - int AlignmentScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::PerColLinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90PerColLinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - > { - - using Impl = - Sm90PerColLinCombPerColBiasEltAct< - StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - using Operation = - fusion::PerColLinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,bool,int64_t>; - using StrideBeta = Stride<_0,bool,int64_t>; - StrideAlpha dAlpha = {_0{}, bool(1), 0}; - StrideBeta dBeta = {_0{}, bool(1), 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // unary op : activation(beta * C + (alpha * acc + bias)) - { // ternary op : beta * C + (alpha * acc + bias) - {beta_ptr, beta, dBeta}, // leaf args : beta - {}, // leaf args : C - { // ternary op : alpha * acc + bias - {alpha_ptr, alpha, dAlpha}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C -template< - class CtaTileShapeMNK, - class EpilogueTile, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - int AlignmentScalar = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90PerColResAddPerColBiasEltAct = - Sm90EVT, // beta * C + activation(alpha * acc + bias) - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector broadcast - Sm90SrcFetch, // C - Sm90EVT, // activation(alpha * acc + bias) - Sm90EVT, // alpha * acc + bias - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast - Sm90AccFetch, // acc - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias - > - > - >; - - template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - int AlignmentScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::PerColResAddPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90PerColResAddPerColBiasEltAct< - CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - > { - - using Impl = - Sm90PerColResAddPerColBiasEltAct< - CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - using Operation = - fusion::PerColResAddPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,bool,int64_t>; - using StrideBeta = Stride<_0,bool,int64_t>; - StrideAlpha dAlpha = {_0{}, bool(1), 0}; - StrideBeta dBeta = {_0{}, bool(1), 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + activation(alpha * acc + bias) - {beta_ptr, beta, dBeta}, // leaf args : beta - {}, // leaf args : C - { // unary op : activation(alpha * acc + bias) - { // ternary op : alpha * acc + bias - {alpha_ptr, alpha, dAlpha}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; - -// We only apply the scaling factor if output is fp8 -template -struct ScaleOutOp { template using Op = cutlass::first; }; -template <> -struct ScaleOutOp { template using Op = cutlass::multiplies; }; -template <> -struct ScaleOutOp { template using Op = cutlass::multiplies; }; - -template -using amax = cutlass::maximum_absolute_value_reduction; // propogate nans - -}; // end namespace detail - -// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias -template< - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerRowBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, 2>, // scale_c * beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha - Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias - > - >; - -// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias -// if D is fp8 -// D = scale_d * activation(Z) -// else -// D = activation(Z) -template< - class CtaTileShapeMNK, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerRowBiasEltAct = - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // activation(Z) - // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias - Sm90ScaledLinCombPerRowBias - >, - Sm90ScalarBroadcast // scale_d - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90ScaledLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90ScaledLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - using Operation = - fusion::ScaledLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - ElementScalar scale_a = ElementScalar(1); - ElementScalar scale_b = ElementScalar(1); - ElementScalar scale_c = ElementScalar(1); - ElementScalar scale_d = ElementScalar(1); - ElementScalar const* scale_a_ptr = nullptr; - ElementScalar const* scale_b_ptr = nullptr; - ElementScalar const* scale_c_ptr = nullptr; - ElementScalar const* scale_d_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d - { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {{scale_d}, - {scale_d_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias -template< - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerColBias = - Sm90EVT, // beta * C + (alpha * acc + bias) - Sm90ScalarBroadcast, 2>, // scale_c * beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias - Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha - Sm90AccFetch, // acc - Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias - > - >; - -// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias -// if D is fp8 -// D = scale_d * activation(Z) -// else -// D = activation(Z) -template< - class CtaTileShapeMNK, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerColBiasEltAct = - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // activation(Z) - // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias - Sm90ScaledLinCombPerColBias - >, - Sm90ScalarBroadcast // scale_d - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledLinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile -> : Sm90ScaledLinCombPerColBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90ScaledLinCombPerColBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - using Operation = - fusion::ScaledLinCombPerColBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - ElementScalar scale_a = ElementScalar(1); - ElementScalar scale_b = ElementScalar(1); - ElementScalar scale_c = ElementScalar(1); - ElementScalar scale_d = ElementScalar(1); - ElementScalar const* scale_a_ptr = nullptr; - ElementScalar const* scale_b_ptr = nullptr; - ElementScalar const* scale_c_ptr = nullptr; - ElementScalar const* scale_d_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - operator typename Impl::Arguments() const { - return - { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d - { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - activation // unary args : activation - }, // end unary op - {{scale_d}, - {scale_d_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias -// if D is fp8 -// amax_d = max(abs(elements in activation(Z))) -// D = scale_d * activation(Z) -// else -// D = activation(Z) -// if Aux is fp8 -// amax_aux = max(abs(elements in Z)) -// Aux = scale_aux * Z -// else -// Aux = Z - -// fp8 aux specialization -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8 = - Sm90SplitTreeVisitor< - // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias, - // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // amax_d - Sm90EVT, // activation(Z) - Sm90SplitTreeFetch // Z - > - >, - Sm90ScalarBroadcast // scale_d - >, - // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) - Sm90EVT, // store(Aux) - Sm90EVT, // Z * scale_aux - Sm90EVT, // amax_aux - Sm90SplitTreeFetch // Z - >, - Sm90ScalarBroadcast // scale_aux - > - > - >; - -// non-fp8 aux specialization -// lets us use some EVT specializations such as relu + uint1b_t aux -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = - // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // amax_d - Sm90EVT, // activation(Z) - Sm90EVT, // Aux = Z - // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias - > - > - >, - Sm90ScalarBroadcast // scale_d - >; - -// dispatcher -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerRowBiasEltActAmaxAux = conditional_t, - Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8< - CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle - >, - Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8< - CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > ->; - - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementAmax, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentAux, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpR2S -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledLinCombPerRowBiasEltActAmaxAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpR2S -> : Sm90ScaledLinCombPerRowBiasEltActAmaxAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, - SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90ScaledLinCombPerRowBiasEltActAmaxAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, - SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - using Operation = - fusion::ScaledLinCombPerRowBiasEltActAmaxAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - ElementScalar scale_a = ElementScalar(1); - ElementScalar scale_b = ElementScalar(1); - ElementScalar scale_c = ElementScalar(1); - ElementScalar scale_d = ElementScalar(1); - ElementScalar const* scale_a_ptr = nullptr; - ElementScalar const* scale_b_ptr = nullptr; - ElementScalar const* scale_c_ptr = nullptr; - ElementScalar const* scale_d_ptr = nullptr; - - ElementScalar scale_aux = ElementScalar(1); - ElementScalar const* scale_aux_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - ElementAmax* amax_D_ptr = nullptr; - ElementAmax* amax_aux_ptr = nullptr; - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux* aux_ptr = nullptr; - StrideAux dAux = {}; - - operator typename Impl::Arguments() const { - // Only compute amax_d if D is fp8 - ElementAmax* amax_D_ptr_ = nullptr; - if constexpr (detail::is_fp8_v) { - amax_D_ptr_ = amax_D_ptr; - } - - // Aux is fp8 -> DAG arguments - if constexpr (detail::is_fp8_v) { - typename Impl::Arguments args; - // always use structured binding to unpack DAG args since it may or may not be a tuple - auto& [Z_args, aux_args, D_args] = args; - - Z_args = - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha ,{_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }; // end ternary op - - D_args = - { // binary op : activation(Z) * scale_d or activation(Z) - { // unary op : reduce(activation(Z)) - { // unary op : activation(Z) - {}, // leaf args : Z - activation // unary args : activation - }, // end unary op - {amax_D_ptr_} // unary args : reduce - }, // end unary op - {{scale_d}, - {scale_d_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - - aux_args = - { // unary op : store(Aux) - { // binary op : Z * scale_d or Z - { // unary op : reduce(Z) - {}, // leaf args : Z - {amax_aux_ptr} // unary args : reduce - }, // end unary op - {{scale_aux}, - {scale_aux_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies - }, // end binary op - {aux_ptr, dAux} // unary args : store - }; // end unary op - - return args; - } - - // Aux is not fp8 -> Tree arguments - else { - return - { // binary op : activation(Z) * scale_d or activation(Z) - { // unary op : reduce(activation(Z)) - { // unary op : activation(Z) - { // unary op : store(Z) - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias - }, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, dAux} // unary args : store - }, // end unary op - activation // unary args : activation - }, // end unary op - {amax_D_ptr_} // unary args : reduce - }, // end unary op - {{scale_d},{scale_d_ptr}}, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - } - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias -// if D is fp8 -// amax_d = max(abs(elements in activation(Z))) -// D = scale_d * activation(Z) -// else -// D = activation(Z) -// if Aux is fp8 -// amax_aux = max(abs(elements in Z)) -// Aux = scale_aux * Z -// else -// Aux = Z - -// fp8 aux specialization -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8 = - Sm90SplitTreeVisitor< - // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias - Sm90ScaledLinCombPerColBias, - // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // amax_d - Sm90EVT, // activation(Z) - Sm90SplitTreeFetch // Z - > - >, - Sm90ScalarBroadcast // scale_d - >, - // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) - Sm90EVT, // store(Aux) - Sm90EVT, // Z * scale_aux - Sm90EVT, // amax_aux - Sm90SplitTreeFetch // Z - >, - Sm90ScalarBroadcast // scale_aux - > - > - >; - -// non-fp8 aux specialization -// lets us use some EVT specializations such as relu + uint1b_t aux -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8 = - // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) - Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d - Sm90EVT, // amax_d - Sm90EVT, // activation(Z) - Sm90EVT, // Aux = Z - // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerColBias - > - > - >, - Sm90ScalarBroadcast // scale_d - >; - -// dispatcher -template< - class CtaTileShapeMNK, - class EpilogueTile, - int StagesD, - class StrideAux, - class SmemLayoutAtom, - class CopyOpR2S, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementAmax = ElementCompute, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90ScaledLinCombPerColBiasEltActAmaxAux = conditional_t, - Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8< - CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle - >, - Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8< - CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > ->; - - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementAmax, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentAux, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpR2S -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::ScaledLinCombPerColBiasEltActAmaxAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpR2S -> : Sm90ScaledLinCombPerColBiasEltActAmaxAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, - SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90ScaledLinCombPerColBiasEltActAmaxAux< - CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, - SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - using Operation = - fusion::ScaledLinCombPerColBiasEltActAmaxAux< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - ElementScalar scale_a = ElementScalar(1); - ElementScalar scale_b = ElementScalar(1); - ElementScalar scale_c = ElementScalar(1); - ElementScalar scale_d = ElementScalar(1); - ElementScalar const* scale_a_ptr = nullptr; - ElementScalar const* scale_b_ptr = nullptr; - ElementScalar const* scale_c_ptr = nullptr; - ElementScalar const* scale_d_ptr = nullptr; - - ElementScalar scale_aux = ElementScalar(1); - ElementScalar const* scale_aux_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using StrideBias = Stride<_0,_1,int64_t>; - ElementBias const* bias_ptr = nullptr; - StrideBias dBias = {}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - ElementAmax* amax_D_ptr = nullptr; - ElementAmax* amax_aux_ptr = nullptr; - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux* aux_ptr = nullptr; - StrideAux dAux = {}; - - operator typename Impl::Arguments() const { - // Only compute amax_d if D is fp8 - ElementAmax* amax_D_ptr_ = nullptr; - if constexpr (detail::is_fp8_v) { - amax_D_ptr_ = amax_D_ptr; - } - - // Aux is fp8 -> DAG arguments - if constexpr (detail::is_fp8_v) { - typename Impl::Arguments args; - // always use structured binding to unpack DAG args since it may or may not be a tuple - auto& [Z_args, aux_args, D_args] = args; - - Z_args = - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias}, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }; // end ternary op - - D_args = - { // binary op : activation(Z) * scale_d or activation(Z) - { // unary op : reduce(activation(Z)) - { // unary op : activation(Z) - {}, // leaf args : Z - activation // unary args : activation - }, // end unary op - {amax_D_ptr_} // unary args : reduce - }, // end unary op - {{scale_d}, - {scale_d_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - - aux_args = - { // unary op : store(Aux) - { // binary op : Z * scale_d or Z - { // unary op : reduce(Z) - {}, // leaf args : Z - {amax_aux_ptr} // unary args : reduce - }, // end unary op - {{scale_aux}, - {scale_aux_ptr} - }, // leaf args : scale_d - {} // binary args : multiplies - }, // end binary op - {aux_ptr, dAux} // unary args : store - }; // end unary op - - return args; - } - - // Aux is not fp8 -> Tree arguments - else { - return - { // binary op : activation(Z) * scale_d or activation(Z) - { // unary op : reduce(activation(Z)) - { // unary op : activation(Z) - { // unary op : store(Z) - { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) - {{beta, scale_c}, - {beta_ptr, scale_c_ptr}, - {dBeta, {_0{}, _0{}, 0}} - }, // leaf args : (scale_c * beta) - {}, // leaf args : C - { // ternary op : (scale_a * scale_b * alpha) * acc + bias - {{alpha, scale_a, scale_b}, - {alpha_ptr, scale_a_ptr, scale_b_ptr}, - {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr, ElementBias(0), dBias - }, // leaf args : bias - {} // ternary args : multiply_add - }, // end ternary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, dAux} // unary args : store - }, // end unary op - activation // unary args : activation - }, // end unary op - {amax_D_ptr_} // unary args : reduce - }, // end unary op - {{scale_d},{scale_d_ptr}}, // leaf args : scale_d - {} // binary args : multiplies or first - }; // end binary op - } - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - class CtaTileShapeMNK, - class EpilogueTile, - int Stages, - class StrideAux, - class SmemLayoutAtom, - class CopyOpS2R, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombDeEltAct = - Sm90EVT, // activation(beta * C + (alpha * acc), aux) - Sm90LinearCombination, // beta * C + (alpha * acc) - Sm90AuxLoad // aux - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementSource, - class ElementScalar, - int AlignmentAux, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpS2R -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombDeEltAct< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpS2R -> : Sm90LinCombDeEltAct< - CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle - > { - - using Impl = - Sm90LinCombDeEltAct< - CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle - >; - using Operation = - fusion::LinCombDeEltAct< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux const* aux_ptr = nullptr; - StrideAux dAux = {}; - - operator typename Impl::Arguments() const { - return - { // binary op : activation(beta * C + (alpha * acc), aux) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, ElementAux(0), dAux}, // leaf args : aux - activation // binary args : activation - }; // end binary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - class CtaTileShapeMNK, - class EpilogueTile, - int Stages, - class StrideAux, - class SmemLayoutAtom, - class CopyOpS2R, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux = ElementOutput, - class ElementBias = ElementOutput, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - int AlignmentAux = 128 / sizeof_bits_v, - int AlignmentBias = 128 / sizeof_bits_v, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombDeEltActDePerRowBias = - Sm90EVT, // Identity for final conversion - Sm90EVT, AlignmentBias>, - Sm90LinCombDeEltAct - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class GmemLayoutTagAux, - template class ActivationFn, - class ElementOutput, - class ElementCompute, - class ElementAux, - class ElementBias, - class ElementSource, - class ElementScalar, - int AlignmentAux, - int AlignmentBias, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class SmemLayoutAtom, - class CopyOpS2R -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombDeEltActDePerRowBias< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >, - CtaTileShapeMNK, - EpilogueTile, - SmemLayoutAtom, - CopyOpS2R -> : Sm90LinCombDeEltActDePerRowBias< - CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - > { - - using Impl = - Sm90LinCombDeEltActDePerRowBias< - CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - using Operation = - fusion::LinCombDeEltActDePerRowBias< - GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle - >; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - using ActivationArguments = typename Sm90Compute::Arguments; - ActivationArguments activation = ActivationArguments(); - - using StrideAux = cutlass::gemm::TagToStrideC_t; - ElementAux const* aux_ptr = nullptr; - StrideAux dAux = {}; - - using StrideBias = Stride<_1,_0,int64_t>; - ElementBias* dbias_ptr = nullptr; - StrideBias dDbias = {}; - - operator typename Impl::Arguments() const { - return - { // unary op : identity/convert - { // unary op : reduce(activation(beta * C + (alpha * acc), aux)) - { // binary op : activation(beta * C + (alpha * acc), aux) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - {aux_ptr, ElementAux(0), dAux}, // leaf args : aux - activation // binary args : activation - }, // end binary op - {dbias_ptr, ElementCompute(0), dDbias} // unary args : reduce - }, // end unary op - {} // unary args : identity/convert - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// D = softmax(top_k(alpha * acc + beta * C)) -template< - int TopK, - int FragmentSize, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinCombTopKSoftmaxCol = - Sm90EVT, // softmax(top_k(beta * C + (alpha * acc))) - Sm90LinearCombination // beta * C + (alpha * acc) - >; - -template < - int TopK, - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombTopKSoftmaxCol, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinCombTopKSoftmaxCol { - - using Impl = Sm90LinCombTopKSoftmaxCol::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombTopKSoftmaxCol; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - operator typename Impl::Arguments() const { - return - { // unary op: activation(beta * C + (alpha * acc)) - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }, // end ternary op - {} // unary args: activation - }; // end unary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Grouped Wgrad Conv -template< - class GroupsPerTile, - class ElementOutput, - class ElementCompute, - class ElementSource = ElementOutput, - class ElementScalar = ElementCompute, - FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest -> -using Sm90LinearCombinationGroupedWgrad = - Sm90EVT, // beta * C + (alpha * acc) - Sm90ScalarBroadcast>, // beta - Sm90SrcFetch, // C - Sm90EVT, // alpha * acc - Sm90ScalarBroadcast>, // alpha - Sm90AccFetchGroupedWgrad // acc - > - >; - -template < - int StagesC, - int StagesD, - int FragmentSize, - bool ReuseSmemC, - bool DelayTmaStore, - class ElementOutput, - class ElementCompute, - class ElementSource, - class ElementScalar, - FloatRoundStyle RoundStyle, - class CtaTileShapeMNK, - class EpilogueTile, - class GroupsPerTile -> -struct FusionCallbacks< - epilogue::Sm90TmaWarpSpecialized, - fusion::LinearCombinationGroupedWgrad, - CtaTileShapeMNK, - EpilogueTile -> : Sm90LinearCombinationGroupedWgrad::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { - - using Impl = Sm90LinearCombinationGroupedWgrad::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; - using Operation = fusion::LinearCombinationGroupedWgrad; - - struct Arguments { - ElementScalar alpha = ElementScalar(1); - ElementScalar beta = ElementScalar(0); - //ElementScalar groups = ElementScalar(1); - ElementScalar const* alpha_ptr = nullptr; - ElementScalar const* beta_ptr = nullptr; - - using StrideAlpha = Stride<_0,_0,int64_t>; - using StrideBeta = Stride<_0,_0,int64_t>; - StrideAlpha dAlpha = {_0{}, _0{}, 0}; - StrideBeta dBeta = {_0{}, _0{}, 0}; - - operator typename Impl::Arguments() const { - return - { // ternary op : beta * C + (alpha * acc) - {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta - {}, // leaf args : C - { // binary op : alpha * acc - {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha - {}, // leaf args : acc - {} // binary args : multiplies - }, // end binary op - {} // ternary args : multiply_add - }; // end ternary op - } - }; - - // Ctor inheritance - using Impl::Impl; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { -template > -struct get_element_aux { - using type = void; -}; - -template -struct get_element_aux> { - using type = typename FusionOpOrCallbacks::ElementAux; -}; - -template -struct get_element_aux, cute::void_t<>> { - using type = typename get_element_aux::type; -}; - -template -struct get_element_aux, cute::void_t::Operation>> { - private: - using Operation = typename FusionCallbacks::Operation; - public: - using type = typename get_element_aux::type; -}; -} // namespace cutlass:epilogue::fusion::detail - -template -using get_element_aux_t = typename detail::get_element_aux::type; - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp deleted file mode 100644 index ae63a7675c12dc4329374815da4d081a6bd885ee..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ /dev/null @@ -1,842 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree compute operations for the sm90 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/detail/helper_macros.hpp" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" -#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// N-nary Elementwise Compute Operation -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// The template argument provided for ComputeFn must be able to accept -// exactly one template parameter. In Standard C++, it's OK for -// ComputeFn to have other template parameters, as long as those have -// defaults. For example, the following struct Foo would work. -// -// template -// struct Foo { -// CUTLASS_HOST_DEVICE auto operator() (A a, B b); -// }; -// -// However, some compilers, such as Clang, require that the argument -// take _exactly_ one template parameter. This is nonstandard C++ -// behavior. One work-around for this case is to create a subclass -// with exactly one template parameter, and then use that subclass as -// the template argument. -// -// template -// struct FooHomogeneous : public Foo {}; -// -template< - template class ComputeFn, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class = void -> -struct Sm90Compute { -private: - using EmptyArguments = typename Sm90VisitorImpl<>::Arguments; - - template - struct ComputeArguments { - using type = EmptyArguments; - }; - - // partial specialization for compute fns that define an Arguments member, e.g. activation hyperparameters - template - struct ComputeArguments> { - using type = typename Fn::Arguments; - }; - -public: - struct SharedStorage { }; - - using Arguments = typename ComputeArguments>::type; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const&, Arguments const& args, void*) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const&, Arguments const&) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_HOST_DEVICE - Sm90Compute() - : params() {} - - CUTLASS_HOST_DEVICE - Sm90Compute(Params const& params, SharedStorage const& shared_storage) - : params(params) {} - - Params const params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(Params const& params) - : params(params) {} - - Params const& params; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const&... frg_inputs) { - return transform_apply(cute::make_tuple(frg_inputs...), - [&] (auto&& frg_input) CUTLASS_LAMBDA_FUNC_INLINE { - using ElementInput = typename cute::remove_cvref_t::Element; - using ConvertInput = NumericArrayConverter; - ConvertInput convert_input{}; - - return convert_input(frg_input); - }, - [&] (auto&&... cvt_frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { - using ComputeOutput = ComputeFn>; - ComputeOutput compute_output{}; - - if constexpr (cute::is_same_v) { - using ElementComputeOutput = - typename cute::remove_cvref_t::Element; - using ConvertOutput = NumericArrayConverter; - ConvertOutput convert_output{}; - return convert_output(compute_output(cvt_frg_inputs...)); - } - else { - using ElementComputeOutput = - typename cute::remove_cvref_t::Element; - using ConvertOutput = NumericArrayConverter; - ConvertOutput convert_output{}; - return convert_output(compute_output(cvt_frg_inputs..., params)); - } - } - ); - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks(params); - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Performance Optimized Specializations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// beta * C + Z -template < - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class InputScaleOp, // beta - class ElementSource, // C - class InputAddOp // Z -> -struct Sm90TreeVisitor< - Sm90Compute().is_zero())>>, - InputScaleOp, - Sm90SrcFetch, - InputAddOp -> : Sm90VisitorImpl< - InputScaleOp, - Sm90SrcFetch, - InputAddOp, - Sm90Compute - > -{ - using Impl = - Sm90VisitorImpl< - InputScaleOp, - Sm90SrcFetch, - InputAddOp, - Sm90Compute - >; - using Params = typename Impl::Params; - using SharedStorage = typename Impl::SharedStorage; - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor() {} - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor( - Params const& params, - SharedStorage const& shared_storage) - : Impl(params, shared_storage) {} - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - auto const& scale_op = get<0>(Impl::ops); - auto const& added_op = get<2>(Impl::ops); - if constexpr (detail::IsScalarBroadcast::value && not is_void_v) { - return (get<2>(scale_op.params_ptr->dScalar[0]) != 0 && scale_op.params_ptr->scalar_ptrs[0] != nullptr) || - is_C_load_needed() || - added_op.is_producer_load_needed(); - } - else { - return is_C_load_needed() || added_op.is_producer_load_needed(); - } - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - auto const& scale_op = get<0>(Impl::ops); - auto const& src_op = get<1>(Impl::ops); - auto const& added_op = get<2>(Impl::ops); - return (not scale_op.is_zero() && src_op.is_C_load_needed()) || added_op.is_C_load_needed(); - } - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks(bool is_C_load_needed, CallbacksImpl&& impl) - : is_C_load_needed(is_C_load_needed), CallbacksImpl(cute::forward(impl)) { } - - bool is_C_load_needed; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_added = get<2>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); - - using ElementZ = typename decltype(frg_added)::Element; - using ConvertZ = NumericArrayConverter; - using ConvertI = NumericArrayConverter; - ConvertZ convert_Z{}; - ConvertI convert_I{}; - - Array frg_I = convert_Z(frg_added); - - if constexpr (!is_void_v) { - Array frg_scalar = get<0>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); - Array frg_source = get<1>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); - - using ElementX = typename decltype(frg_scalar)::Element; - using ElementY = typename decltype(frg_source)::Element; - using ConvertX = NumericArrayConverter; - using ConvertY = NumericArrayConverter; - using ComputeI = multiply_add>; - ConvertX convert_X{}; - ConvertY convert_Y{}; - ComputeI compute_I{}; - - frg_I = compute_I(convert_X(frg_scalar), convert_Y(frg_source), frg_I); - } - - return convert_I(frg_I); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_tuple = Impl::template get_consumer_store_callbacks(args); - bool is_C_load_needed = this->is_C_load_needed(); - if (not is_C_load_needed) { - cute::clear(args.tCrC); - } - return ConsumerStoreCallbacks( - is_C_load_needed, std::move(callbacks_tuple)); - } -}; - -// ReLU with aux bit tensor dReLU/dZ -// Aux(i) = Z(i) >= 0 ? 1 : 0 -namespace detail { -// Placeholder node so we can retain standard EVT structure -template -struct Sm90ReLUAuxStore : Sm90VisitorImpl<> { - struct SharedStorage {}; - - struct Arguments { - cutlass::uint1b_t* ptr_aux = nullptr; - StrideMNL dAux = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90ReLUAuxStore() { } - - CUTLASS_HOST_DEVICE - Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage) { } -}; -} // namespace detail - -// Specialization on the generic compute+aux EVT -template < - // Compute node - template class Activation, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - // Aux node - int Stages, - class EpilogueTile, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpR2S, - int Alignment, - bool EnableNullptr, - // Input node - class InputOp -> -struct Sm90TreeVisitor< - Sm90Compute, cutlass::epilogue::thread::ReLu> || - cute::is_same_v, cutlass::epilogue::thread::Clamp> || - cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU> >>, - Sm90TreeVisitor< - Sm90AuxStore< - Stages, - EpilogueTile, - cutlass::uint1b_t, - RoundStyle, - StrideMNL, - SmemLayoutAtom, - CopyOpR2S, - Alignment, - EnableNullptr - >, - InputOp - > -> : Sm90VisitorImpl< - Sm90VisitorImpl< - InputOp, - detail::Sm90ReLUAuxStore - >, - Sm90Compute - > -{ - using Impl = - Sm90VisitorImpl< - Sm90VisitorImpl< - InputOp, - detail::Sm90ReLUAuxStore - >, - Sm90Compute - >; - using Params = typename Impl::Params; - using SharedStorage = typename Impl::SharedStorage; - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor() {} - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor(Params const& params_, SharedStorage const& shared_storage) - : params(params_), Impl(params_, shared_storage) {} - - Params const& params; - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rAux, - GTensor&& tC_gAux, - CTensor tC_cAux, - ThrResidue residue_tC_cAux, - Params const& params, - CallbacksImpl&& impl) - : tC_rAux(cute::forward(tC_rAux)), - tC_gAux(cute::forward(tC_gAux)), - tC_cAux(tC_cAux), - residue_tC_cAux(residue_tC_cAux), - params(params), - CallbacksImpl(cute::forward(impl)) {} - - RTensor tC_rAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ThrResidue residue_tC_cAux; - Params const& params; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - // Unpack callbacks + params - auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; - auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; - auto const& [params_input_aux, params_compute] = params; - auto const& [params_input, params_aux] = params_input_aux; - - // Visit the input node - Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n); - - // Compute activation + aux - using ElementInput = typename decltype(frg_input)::Element; - using ConvertInput = NumericArrayConverter; - using ConvertAux = PackPredicates; - using ComputeOutput = Activation; - using ConvertOutput = NumericArrayConverter; - ConvertInput convert_input{}; - ComputeOutput relu{}; - ConvertAux convert_aux{}; - ConvertOutput convert_output{}; - - Array frg_compute = convert_input(frg_input); - bool frg_aux[FragmentSize]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - ElementCompute pre_relu = frg_compute[i]; - if constexpr (cute::is_same_v, cutlass::epilogue::thread::Clamp> || - cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU>) { - frg_compute[i] = relu(frg_compute[i], params_compute); - } - else { - frg_compute[i] = relu(frg_compute[i]); - } - if constexpr (cute::is_same_v) { - uint32_t aux; - asm volatile("set.equ.u32.f32 %0, %1, %2;\n" : "=r"(aux) : "f"(frg_compute[i]), "f"(pre_relu)); // NaN outputs 1 in Aux - frg_aux[i] = static_cast(aux); - } else if constexpr (cute::is_same_v) { - uint32_t aux; - cutlass::half_t compute = frg_compute[i]; - asm volatile("set.equ.u32.f16 %0, %1, %2;\n" : "=r"(aux) : "h"(compute.raw()), "h"(pre_relu.raw())); // NaN outputs 1 in Aux - frg_aux[i] = static_cast(aux); - } else { - frg_aux[i] = frg_compute[i] == pre_relu; - } - } - - static_assert(FragmentSize % 8 == 0, "Predicate vector must be byte-aligned"); - Tensor tC_rAux_frg = recast(coalesce(tC_rAux(_,_,_,epi_m,epi_n))); // (EPI_V) - tC_rAux_frg(epi_v) = convert_aux(frg_aux); - - return convert_output(frg_compute); - } - - CUTLASS_DEVICE void - end() { - // Unpack callbacks + params - auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; - auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; - auto const& [params_input_aux, params_compute] = params; - auto const& [params_input, params_aux] = params_input_aux; - - // Visit the input node - callbacks_input.end(); - - // Nullptr is no-op - if constexpr (EnableNullptr) { - if (params_aux.ptr_aux == nullptr) { - return; - } - } - - // Compute vectorization - constexpr auto MCL = decltype(max_common_layout(tC_rAux, tC_gAux)){}; - constexpr int V = cute::min(Alignment, size(MCL)); - // Copy vectorizes into byte-aligned stores - if constexpr (V > 1 && V % 8 == 0) { - using VecType = uint_bit_t; - Tensor tC_rAux_vec = recast(tC_rAux); - Tensor tC_gAux_vec = recast(tC_gAux); - Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); - Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); - copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec); - } - // sub-byte vectorization, must serialize threads - else { - // Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this) - int lane_idx = canonical_lane_idx(); - Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); - CUTLASS_PRAGMA_NO_UNROLL - for (int i = 0; i < NumThreadsPerWarp; ++i) { - if (lane_idx == i) { - copy_if(tC_pAux, tC_rAux, tC_gAux); - } - __syncwarp(); - } - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - // Unpack params - auto const& [params_input_aux, params_compute] = params; - auto const& [params_input, params_aux] = params_input_aux; - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - gmem_ptr ptr_aux = make_gmem_ptr(params_aux.ptr_aux); - Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L) - Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - - Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gAux, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tC_rAux = make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - auto callbacks_impl = Impl::template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks( - cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_tCcD, params, cute::move(callbacks_impl)); - } -}; - -// Aux load for uint1b_t -template < - int Stages, - class EpilogueTile, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpS2R, - int Alignment, - bool EnableNullptr -> -struct Sm90AuxLoad< - Stages, - EpilogueTile, - cutlass::uint1b_t, - StrideMNL, - SmemLayoutAtom, - CopyOpS2R, - Alignment, - EnableNullptr -> { - static_assert(Alignment % 128 == 0, "sub-16B alignment not supported yet"); - - struct SharedStorage {}; - - struct Arguments { - cutlass::uint1b_t const* ptr_aux = nullptr; - cutlass::uint1b_t null_default = cutlass::uint1b_t(0); - StrideMNL dAux = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad() { } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad(Params const& params, SharedStorage const&) - : params(params) { } - - Params const params; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, CTensor tC_cAux_, ThrResidue residue_tC_cAux_, Params const& params_) - : tC_rAux(cute::forward(tC_rAux_)), - tC_gAux(cute::forward(tC_gAux_)), - tC_cAux(tC_cAux_), - residue_tC_cAux(residue_tC_cAux_), - params(params_) {} - - RTensor tC_rAux; // (CPY,CPY_M,CPY_N,{EPI_M,EPI_N}) - GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ThrResidue residue_tC_cAux; - Params const& params; - - CUTLASS_DEVICE void - begin() { - if constexpr (decltype(cute::rank(tC_rAux))::value == 5) { - if constexpr (EnableNullptr) { - if (params.ptr_aux == nullptr) { - return; - } - } - - constexpr auto MCL = decltype(max_common_layout(tC_rAux, tC_gAux)){}; - constexpr int V = cute::min(Alignment, size(MCL)); - if constexpr (V > 1) { - using VecType = uint_bit_t; - Tensor tC_gAux_vec = recast(tC_gAux); - Tensor tC_rAux_vec = recast(tC_rAux); - Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); - Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); - copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec); - } - else { - Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); - copy_if(tC_pAux, tC_gAux, tC_rAux); - } - } - } - - CUTLASS_DEVICE void - begin_loop(int epi_m, int epi_n) { - if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { - if constexpr (EnableNullptr) { - if (params.ptr_aux == nullptr) { - return; - } - } - - Tensor tC_pAux = cute::lazy::transform(tC_cAux(_,_,_,epi_m,epi_n), [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); - copy_if(tC_pAux, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); - } - } - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - using ElementRegister = typename remove_cvref_t::value_type; - if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { - return recast>(coalesce(tC_rAux))(epi_v); - } - else { - return recast>(coalesce(tC_rAux(_,_,_,epi_m,epi_n)))(epi_v); - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - gmem_ptr ptr_aux = make_gmem_ptr(params.ptr_aux); - Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L) - Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - - Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gAux, args.epi_tile, args.tiled_copy, args.thread_idx); - - // If byte-unaligned vectorization, store in registers as uint32_t to reduce redundant pack+unpack instruction sequences - constexpr int V = decltype(max_common_vector(tC_gAux.layout(), make_layout(tC_gAux.shape())))::value; - Tensor tC_rAux = [&] () CUTLASS_LAMBDA_FUNC_INLINE { - if constexpr (V % 8 != 0) { - return make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) - } else { - return make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - } - }(); - - if constexpr (EnableNullptr) { - if (params.ptr_aux == nullptr) { - fill(tC_rAux, params.null_default); - } - } - - return ConsumerStoreCallbacks( - cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_tCcD, params); - } -}; - -// dReLU specialization -template< - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle -> -struct Sm90Compute< - cutlass::epilogue::thread::dReLU, - ElementOutput, - ElementCompute, - RoundStyle -> : Sm90VisitorImpl<> { - - using Sm90VisitorImpl<>::Sm90VisitorImpl; - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input, - Array const& frg_aux) { - using ConvertInput = NumericArrayConverter; - using ComputeOutput = cutlass::epilogue::thread::dReLU>; - using ConvertOutput = NumericArrayConverter; - ConvertInput convert_input{}; - ComputeOutput compute_output{}; - ConvertOutput convert_output{}; - - return convert_output(compute_output(convert_input(frg_input), frg_aux)); // don't convert frg_aux for dReLU - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks(); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp deleted file mode 100644 index 535d8b082d44ff796fe2efc4e1531b4a3dc2674c..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ /dev/null @@ -1,1492 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree load operations for the sm90 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/detail/helper_macros.hpp" - -#include "cute/tensor.hpp" -#include "sm90_visitor_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Elementwise Fetch Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// returns accumulator -struct Sm90AccFetch : Sm90VisitorImpl<> { - - using Sm90VisitorImpl<>::Sm90VisitorImpl; - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - return frg_acc; - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks{}; - } -}; - -// Split tree visitor fetches intermediate results from temporary accumulators -using Sm90SplitTreeFetch = Sm90AccFetch; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// returns C -template -struct Sm90SrcFetch : Sm90VisitorImpl<> { - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return is_C_load_needed(); - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return not is_void_v; - } - - CUTLASS_DEVICE bool - is_zero() const { - return is_void_v; - } - - using Sm90VisitorImpl<>::Sm90VisitorImpl; - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(SrcTensor const& tCrC) - : tCrC(tCrC) {} - - SrcTensor const& tCrC; // (CPY,CPY_M,CPY_N) - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - return recast>(tCrC)(epi_v); - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - // register type may differ from logical type so we can't assert matching types here - return ConsumerStoreCallbacks(args.tCrC); - } -}; - -// returns accumulator in Grouped Conv Wgrad -template -struct Sm90AccFetchGroupedWgrad : Sm90VisitorImpl<> { - - using Sm90VisitorImpl<>::Sm90VisitorImpl; - using GroupsPerTile = GroupsPerTile_; - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(int32_t thread_idx) - : thread_idx(thread_idx) { } - - int32_t thread_idx; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - - Array frg_acc_rst; - int warp_id = thread_idx / 32; - - // In Grouped Wgrad, only diagonal block data is valid and the others is wrong and useless. - // One block size is C/G x C/G. Note that C/G = Tile_N / GroupsPerTile. - // Copy diagonal block ACC into the first block Col which is the output tensor size Tile_M * C/G. - // Then we can store the valid output tensor tile directly. - if constexpr ( cute::is_same_v ) { - frg_acc_rst = frg_acc; - } - else if constexpr ( cute::is_same_v ) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 16; i++) { - frg_acc_rst[i] = frg_acc[i + warp_id / 2 * 16]; - } - } - else if constexpr ( cute::is_same_v ) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 8; i++) { - frg_acc_rst[i] = frg_acc[i + warp_id * 8]; - } - } - else if constexpr ( cute::is_same_v ) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; i++) { - frg_acc_rst[i] = frg_acc[i + warp_id * 8 + i / 2 * 4]; - } - } - - return frg_acc_rst; - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks(args.thread_idx); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Elementwise Load Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int Stages, - class EpilogueTile, - class Element, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpS2R, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true // Fallback scalar broadcast for nullptr params -> -struct Sm90AuxLoad { - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - - constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); - // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) - using SmemShapeTma = decltype(make_shape( - max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), - max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); - using SmemLayoutTma = decltype(tile_to_shape( - SmemLayoutAtom{}, SmemShapeTma{}, - cute::conditional_t, Step<_1,_2>>{} )); - using SmemLayout = decltype(tile_to_shape( - SmemLayoutTma{}, - make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), - cute::conditional_t, Step<_1,_2,_3>>{} )); - using CopyOpG2S = - SM90_TMA_LOAD - ; - - struct SharedStorage { - alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) - array_aligned smem_aux; - }; - - struct Arguments { - Element const* ptr_aux = nullptr; - Element null_default = Element(0); - StrideMNL dAux = {}; - }; - - struct Params { - using TMA_Aux = decltype(make_tma_copy( - CopyOpG2S{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideMNL{}, int32_t(0)), append<3>(StrideMNL{}, _0{})), - take<0,2>(SmemLayoutTma{}))); - TMA_Aux tma_load_aux; - Element null_default = Element(0); - bool use_default = false; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto M_AUX = - size(M) - ; - Tensor tensor_aux = make_tensor(make_gmem_ptr(args.ptr_aux), make_layout(make_shape(M_AUX,N,L), append<3>(args.dAux, _0{}))); - typename Params::TMA_Aux tma_load_aux = make_tma_copy(CopyOpG2S{}, tensor_aux, take<0,2>(SmemLayoutTma{})); - - bool use_default = false; - if constexpr (EnableNullptr) { - use_default = args.ptr_aux == nullptr; - } - - return Params{tma_load_aux, args.null_default, use_default}; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad() { } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms), - smem_aux(const_cast(shared_storage.smem_aux.data())) { } - - Params const* params_ptr; - Element* smem_aux; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return true; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_zero() const { - return (params_ptr->use_default && params_ptr->null_default == Element(0)); - } - - template - struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { - CUTLASS_DEVICE - ProducerLoadCallbacks(GTensor&& bGS_gAux, STensor&& bGS_sAux, Params const* params_ptr) - : bGS_gAux(cute::forward(bGS_gAux)), - bGS_sAux(cute::forward(bGS_sAux)), - params_ptr(params_ptr) {} - - GTensor bGS_gAux; // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) - STensor bGS_sAux; // (TMA,TMA_M,TMA_N,PIPE) - Params const* params_ptr; - - CUTLASS_DEVICE void - step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { - if constexpr (EnableNullptr) { - if (params_ptr->use_default) { - return; - } - } - - if (issue_tma_load) { - // Increment the expected transaction bytes of the current stage's mbarrier by the subtile's byte-size - constexpr uint32_t copy_bytes = size(take<0,2>(SmemLayout{})) * sizeof_bits_v / 8; - cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); - // Issue the TMA load - constexpr uint16_t mcast_mask = 0; - int load_pipe_index = load_iteration % Stages; - copy(params_ptr->tma_load_aux.with(*full_mbarrier_ptr, mcast_mask), - bGS_gAux(_,_,_,epi_m,epi_n), bGS_sAux(_,_,_,load_pipe_index)); - } - } - }; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - auto coord_shape = - make_coord(m, n, l) - ; - Tensor mAux_mn = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor mAux = coalesce(mAux_mn, take<0,2>(args.tile_shape_mnk)); - Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), coord_shape); // (CTA_M,CTA_N) - - Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) - - ThrCopy thrblk_g2s = params_ptr->tma_load_aux.get_slice(_0{}); - Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) - Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) - - return ProducerLoadCallbacks( - cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr); - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tC_rAux, TiledS2R tiled_s2r, STensorS2R&& tSR_sAux, Params const* params_ptr) - : tC_rAux(cute::forward(tC_rAux)), - tiled_s2r(tiled_s2r), - tSR_sAux(cute::forward(tSR_sAux)), - params_ptr(params_ptr) { } - - TiledS2R tiled_s2r; - RTensor tC_rAux; // (CPY,CPY_M,CPY_N) - STensorS2R tSR_sAux; // (S2R,S2R_M,S2R_N,PIPE) - Params const* params_ptr; - - CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - if constexpr (EnableNullptr) { - if (params_ptr->use_default) { - fill(tC_rAux, params_ptr->null_default); - return; - } - } - - using RLayoutS2R = decltype(cute::layout(TiledS2R{}.get_slice(0).retile_S(RTensor{}))); - Tensor tSR_rAux = make_tensor(tC_rAux.data(), RLayoutS2R{}); // (S2R,S2R_M,S2R_N) - - int load_pipe_index = load_iteration % Stages; - copy(tiled_s2r, tSR_sAux(_,_,_,load_pipe_index), tSR_rAux); - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) - - return tC_rAux_frg(epi_v); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - - Tensor mAux_mn = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor mAux = coalesce(mAux_mn, take<0,2>(args.tile_shape_mnk)); - Tensor tC_gAux = sm90_partition_for_epilogue(mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) - - auto tiled_s2r = conditional_return( - make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), - make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) - ); - Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) - auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) - - return ConsumerStoreCallbacks( - cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr); - } -}; - -template < - class Element, - class EpilogueTile, // Unused - class LayoutOrStrideMNL, - class SmemLayoutAtom, // Unused - class CopyOpS2R, // Unused - int Alignment, - bool EnableNullptr -> -struct Sm90AuxLoad< - 0, EpilogueTile, Element, LayoutOrStrideMNL, - SmemLayoutAtom, CopyOpS2R, Alignment, EnableNullptr -> { - using ElementAux = Element; - using StrideMNL = cutlass::gemm::TagToStrideC_t; - - struct SharedStorage { }; - - struct Arguments { - Element const* ptr_aux = nullptr; - Element null_default = Element(0); - StrideMNL dAux = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad() { } - - CUTLASS_HOST_DEVICE - Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { } - - Params const* params_ptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template< - class GTensorG2R, - class RTensor, - class CTensorG2R, - class ProblemShapeMNL - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(GTensorG2R&& tC_gAux, - RTensor&& tC_rAux, - CTensorG2R&& tC_cAux, - ProblemShapeMNL problem_shape_mnl, - Params const* params_ptr) - : tC_gAux(cute::forward(tC_gAux)), - tC_rAux(cute::forward(tC_rAux)), - tC_cAux(cute::forward(tC_cAux)), - problem_shape_mnl(problem_shape_mnl), - params_ptr(params_ptr) {} - - GTensorG2R tC_gAux; - RTensor tC_rAux; - CTensorG2R tC_cAux; - ProblemShapeMNL problem_shape_mnl; - Params const* params_ptr; - - CUTLASS_DEVICE void - begin_loop(int epi_m, int epi_n) { - if constexpr (EnableNullptr) { - if (params_ptr->ptr_aux == nullptr) { - fill(tC_rAux, params_ptr->null_default); - return; - } - } - constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; - constexpr int V = cute::min(Alignment, size(MCL)); - - Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); - Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); - - Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int{}))); - Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); }); - - copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec); - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - return recast>(tC_rAux)(epi_v); - } - }; - - template < - bool ReferenceSrc, - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - - auto problem_shape_mnl = make_shape(M,N,L); - - // Gmem Tensor - Tensor mAux = make_tensor( - make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux - ); - Tensor tC_gAux = sm90_partition_for_epilogue( - mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - // Register Tensor - Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); - - // Predication support - Tensor coordAux = make_identity_tensor(shape(mAux)); - Tensor tC_cAux = sm90_partition_for_epilogue( - coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - return ConsumerStoreCallbacks( - cute::move(tC_gAux), - cute::move(tC_rAux), - cute::move(tC_cAux), - problem_shape_mnl, - params_ptr - ); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Broadcast Load Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Scalar broadcast -// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors -template< - class Element, - class StrideMNL_ = Stride<_0,_0,_0>, - int BroadcastCount = 1, - template class ReductionFn = multiplies -> -struct Sm90ScalarBroadcast { - using StrideMNL = StrideMNL_; - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); - - struct SharedStorage { }; - - struct Arguments { - Element scalars[BroadcastCount] = {}; - Element const* scalar_ptrs[BroadcastCount] = {}; - StrideMNL dScalar[BroadcastCount] = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter *cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - // This must be called after update_scalar is called - CUTLASS_DEVICE bool - is_zero() const { - if (get<2>(params_ptr->dScalar[0]) == 0) { - // Only 1 batch - return scalar == Element(0); - } - else { - // multiple batch - if (valid_scalar == false) { - // for stridedBatch kernel, if ptr has a valid address, we need to enable the epi_load warps. - return params_ptr->scalar_ptrs[0] == nullptr; - } - else { - // Check whether each batch is ZERO or not. - return scalar == Element(0); - } - } - } - - CUTLASS_HOST_DEVICE - Sm90ScalarBroadcast() { } - - CUTLASS_HOST_DEVICE - Sm90ScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { - // Get the scalar for non-batched broadcast - if (size<2>(params_ptr->dScalar[0]) == 0) { - update_scalar(); - } - } - - Element scalar; - bool valid_scalar = false; - Params const* params_ptr; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - // Get the scalar for batched broadcast - if (size<2>(params_ptr->dScalar[0]) != 0) { - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - } - - return EmptyProducerLoadCallbacks{}; - } - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(Element scalar) - : scalar(scalar) {} - - Element scalar; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_scalar; - frg_scalar.fill(scalar); - - return frg_scalar; - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - // Get the scalar for batched broadcast - if (get<2>(params_ptr->dScalar[0]) != 0) { - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - } - - return ConsumerStoreCallbacks(scalar); - } - -private: - CUTLASS_DEVICE void - update_scalar(int l_coord = 0) { - valid_scalar = true; - int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); - - if (params_ptr->scalar_ptrs[0] != nullptr) { - scalar = params_ptr->scalar_ptrs[0][l_offset]; - } - else { - // batch stride is ignored for nullptr fallback - scalar = params_ptr->scalars[0]; - } - - // Do reduction over multiple broadcasts if necessary - ReductionFn reduction_fn; - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < BroadcastCount; ++i) { - if (params_ptr->scalar_ptrs[i] != nullptr) { - int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); - scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); - } - else { - // batch stride is ignored for nullptr fallback - scalar = reduction_fn(scalar, params_ptr->scalars[i]); - } - } - } - - template - CUTLASS_DEVICE void - update_scalar(cute::tuple) { - // Only support multiple L-modes with fully-broadcast scalar - scalar = params_ptr->scalars[0]; - valid_scalar = true; - } -}; - -// Scalar broadcast -// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors -template< - class Element, - class StrideMNL_ = Stride<_0,_0,_0>, - int BroadcastCount = 1, - template class ReductionFn = multiplies -> -struct Sm90ScalarBroadcastPtrArray { - using StrideMNL = StrideMNL_; - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); - - struct SharedStorage { }; - - struct Arguments { - Element scalars[BroadcastCount] = {}; - Element const* scalar_ptrs[BroadcastCount] = {}; - Element const* const* scalar_ptr_arrays[BroadcastCount] = {}; - StrideMNL dScalar[BroadcastCount] = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter *cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - // producer load is needed if Element is not void - return !cute::is_void_v; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - // This must be called after update_scalar is called - CUTLASS_DEVICE bool - is_zero() const { - return scalar == Element(0); - } - - CUTLASS_HOST_DEVICE - Sm90ScalarBroadcastPtrArray() { } - - CUTLASS_HOST_DEVICE - Sm90ScalarBroadcastPtrArray(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { - // Get the scalar for non-batched broadcast - if (size<2>(params_ptr->dScalar[0]) == 0) { - update_scalar(); - } - } - - Element scalar; - Params const* params_ptr; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - // Always refresh scalar with the current group index so per-group - // alpha/beta values (provided through pointer arrays) are loaded - // correctly even when the L-stride is zero. - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - - return EmptyProducerLoadCallbacks{}; - } - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(Element scalar) - : scalar(scalar) {} - - Element scalar; - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_scalar; - frg_scalar.fill(scalar); - - return frg_scalar; - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - - return ConsumerStoreCallbacks(scalar); - } - -private: - CUTLASS_DEVICE void - update_scalar(int l_coord = 0) { - int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); - - if (params_ptr->scalar_ptr_arrays[0] != nullptr) { - // Pointer-array variant: each entry already points to the scalar of a group. - scalar = *(params_ptr->scalar_ptr_arrays[0][l_coord]); - } - else if (params_ptr->scalar_ptrs[0] != nullptr) { - // Strided pointer variant. - scalar = params_ptr->scalar_ptrs[0][l_offset]; - } - else { - // Literal fallback. - scalar = params_ptr->scalars[0]; - } - - // Do reduction over multiple broadcasts if necessary - ReductionFn reduction_fn; - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < BroadcastCount; ++i) { - - if (params_ptr->scalar_ptr_arrays[i] != nullptr) { - scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][l_coord])); - } - else if (params_ptr->scalar_ptrs[i] != nullptr) { - int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); - scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); - } - else { - scalar = reduction_fn(scalar, params_ptr->scalars[i]); - } - } - } -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -template -[[deprecated("row broadcast only uses 0 stages")]] constexpr int -compute_row_broadcast_stages() { - return ceil_div(StagesC, size<1>(zipped_divide(make_layout(take<0,2>(CtaTileShapeMNK{})), EpilogueTile{}))) + 1; -} - -} - -// Row vector broadcast -template< - int Stages, - class CtaTileShapeMNK, - class ElementInput_, - class ElementCompute = cute::remove_pointer_t, - class StrideMNL_ = Stride<_0,_1,_0>, - int Alignment = 128 / sizeof_bits_v>, - bool EnableNullptr = true // Fallback scalar broadcast for nullptr params -> -struct Sm90RowBroadcast { - using StrideMNL = StrideMNL_; - // Get base element input type. - using ElementInput = cute::remove_pointer_t; - // Check if input is an array of pointers. - static constexpr bool IsArrayOfPointers = is_same_v; - using PtrRowType = cute::conditional_t; - - static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining"); - - static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // row vector or scalar broadcast - static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast); - - struct SharedStorage { - array_aligned(CtaTileShapeMNK{})> smem; - }; - - struct Arguments { - PtrRowType ptr_row = nullptr; - ElementInput null_default = ElementInput(0); - StrideMNL dRow = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90RowBroadcast() { } - - CUTLASS_HOST_DEVICE - Sm90RowBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params), is_zero_(false), - smem(const_cast(shared_storage.smem.data())) { - auto const& [stride_M, stride_N, stride_L] = params.dRow; - // Nullptr default - if (EnableNullptr && params.ptr_row == nullptr) { - is_zero_ = params.null_default == ElementCompute(0); - } - // Dynamic non-batched scalar broadcast - else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) { - if constexpr (!IsArrayOfPointers) { - is_zero_ = params.ptr_row[0] == ElementInput(0); - } - } - } - - Params params; - bool is_zero_ = false; - ElementInput *smem = nullptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_zero() const { - return is_zero_; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, - GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, - SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, - Residue residue_cRow_, Params const& params_) - : tGS_gRow(tGS_gRow_) - , tGS_sRow(tGS_sRow_) - , tGS_cRow(tGS_cRow_) - , tiled_G2S(tiled_g2s_) - , tSR_sRow(tSR_sRow_) - , tSR_rRow(tSR_rRow_) - , residue_cRow(residue_cRow_) - , params(params_) { - } - - GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) - GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) - GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) - Tiled_G2S tiled_G2S; - - SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - Residue residue_cRow; // (m, n) - Params const& params; - - CUTLASS_DEVICE void - begin() { - bool is_nullptr = EnableNullptr && params.ptr_row == nullptr; - - Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); - Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); - Tensor tGS_cRow_flt = filter_zeros(tGS_cRow, tGS_gRow.stride()); - - for (int i = 0; i < size(tGS_gRow_flt); ++i) { - if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { - continue; // OOB of SMEM, - } - if (not is_nullptr && elem_less(tGS_cRow_flt(i), residue_cRow)) { - tGS_sRow_flt(i) = tGS_gRow_flt(i); // issue async gmem to smem load - } - else { - tGS_sRow_flt(i) = params.null_default; // fill OOB values so smem to RF load can issue without predication - } - } - } - - CUTLASS_DEVICE bool - begin_sync_needed() const { - return true; // Ensure visibility of async gmem to smem loads - } - - CUTLASS_DEVICE void - begin_loop(int epi_m, int epi_n) { - if (epi_m == 0) { // Assumes M-major subtile loop - Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); - Tensor tSR_rRow_flt = make_tensor_like(tSR_sRow_flt); - copy_aligned(tSR_sRow_flt, tSR_rRow_flt); - - constexpr int FrgSize = size(tSR_rRow_flt); - using FrgInput = Array; - using FrgCompute = Array; - using ConvertInput = NumericArrayConverter; - - Tensor tSR_rRow_input_frg = recast(coalesce(tSR_rRow_flt)); - Tensor tSR_rRow_compute_frg = recast(filter(tSR_rRow)); - ConvertInput convert_input{}; - - tSR_rRow_compute_frg(_0{}) = convert_input(tSR_rRow_input_frg(_0{})); - } - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_row; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); - } - - return frg_row; - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - using ThreadCount = decltype(size(args.tiled_copy)); - - auto layout_N = [&] () CUTLASS_LAMBDA_FUNC_INLINE { - auto shape_N = get<1>(args.problem_shape_mnkl); - if constexpr (IsDynamicBroadcast) { - auto stride_N = repeat_like(shape_N, int(0)); - if (get<1>(params.dRow) == bool(1)) { - stride_N = transform_leaf(compact_major(shape_N), - [] (auto const& stride) { return static_cast(stride); } - ); - } - return make_layout(shape_N, stride_N); - } - else { - return make_layout(shape_N); - } - }(); - - auto layout_M = make_layout(M, repeat_like(M, _0{})); - auto layout_L = make_layout(L, get<2>(params.dRow)); - ElementInput const* ptr_row = nullptr; - if constexpr(IsArrayOfPointers) { - if (!(EnableNullptr && params.ptr_row == nullptr)) { - ptr_row = params.ptr_row[l]; - } - } else { - ptr_row = params.ptr_row; - } - Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L)); - Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem), - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) - //// G2S: Gmem to Smem - auto tiled_g2s = make_tiled_copy(Copy_Atom{}, - Layout< Shape<_1, ThreadCount>, - Stride<_0, _1>>{}, - Layout<_1>{}); - auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); - Tensor tGS_gRow = thr_g2s.partition_S(gRow); - Tensor tGS_sRow = thr_g2s.partition_D(sRow); - - //// G2S: Coord - Tensor tGS_cRow = thr_g2s.partition_S(args.cD); - - //// S2R: Smem to Reg - Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) - - return ConsumerStoreCallbacks( - tGS_gRow, - tGS_sRow, - tGS_cRow, tiled_g2s, - tSR_sRow, - tSR_rRow, - args.residue_cD, - params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Column vector broadcast -template< - int Stages, - class CtaTileShapeMNK, - class ElementInput_, - class ElementCompute = cute::remove_pointer_t, - class StrideMNL_ = Stride<_1,_0,_0>, - int Alignment = 128 / sizeof_bits_v>, - bool EnableNullptr = true // Fallback scalar broadcast for nullptr params -> -struct Sm90ColBroadcast { - using StrideMNL = StrideMNL_; - // Get base element input type. - using ElementInput = cute::remove_pointer_t; - // Check if input is an array of pointers. - static constexpr bool IsArrayOfPointers = is_same_v; - using PtrColType = cute::conditional_t; - - static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining"); - - static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // Column vector or scalar broadcast - static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{} || IsDynamicBroadcast); - - // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem - struct SharedStorage { }; - - struct Arguments { - PtrColType ptr_col = nullptr; - ElementInput null_default = ElementInput(0); - StrideMNL dCol = {}; - }; - - struct Params { - PtrColType ptr_col = nullptr; - ElementCompute null_default = ElementCompute(0); - StrideMNL dCol = {}; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return {args.ptr_col, ElementCompute(args.null_default), args.dCol}; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_zero() const { - return is_zero_; - } - - CUTLASS_HOST_DEVICE - Sm90ColBroadcast() { } - - CUTLASS_HOST_DEVICE - Sm90ColBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params), is_zero_(false) { - auto const& [stride_M, stride_N, stride_L] = params.dCol; - // Nullptr default - if (EnableNullptr && params.ptr_col == nullptr) { - is_zero_ = params.null_default == ElementCompute(0); - } - // Dynamic non-batched scalar broadcast - else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) { - if constexpr (!IsArrayOfPointers) { - is_zero_ = params.ptr_col[0] == ElementInput(0); - } - } - } - - Params params; - bool is_zero_; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(GTensor tCgCol_, RTensor tCrCol_, CTensor tCcCol_, ThrResidue residue_tCcCol_, Params const& params_) - : tCgCol(tCgCol_), - tCrCol(tCrCol_), - tCcCol(tCcCol_), - residue_tCcCol(residue_tCcCol_), - params(params_) { - if (EnableNullptr && params.ptr_col == nullptr) { - fill(tCrCol, params.null_default); - } - } - - GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ThrResidue residue_tCcCol; - Params const& params; - - CUTLASS_DEVICE void - begin() { - if (EnableNullptr && params.ptr_col == nullptr) { - return; - } - - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - Tensor tCgCol_flt = filter_zeros(tCgCol); - Tensor tCrCol_flt = make_tensor_like(filter_zeros(tCrCol)); - Tensor tCcCol_flt = filter_zeros(tCcCol, tCgCol.stride()); - - constexpr auto MCL = decltype(max_common_layout(tCgCol_flt, tCrCol_flt)){}; - constexpr int V = cute::min(Alignment, size(MCL)); - if constexpr (V > 1) { - using VecType = uint_bit_t>; - Tensor tCgCol_vec = recast(coalesce(tCgCol_flt)); - Tensor tCrCol_vec = recast(coalesce(tCrCol_flt)); - Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int{}))); - Tensor tCpCol_vec = cute::lazy::transform(tCcCol_vec, [&](auto const& c){ return elem_less(c, residue_tCcCol); }); - copy_if(tCpCol_vec, tCgCol_vec, tCrCol_vec); - } - else { - Tensor tCpCol_flt = cute::lazy::transform(tCcCol_flt, [&](auto const& c){ return elem_less(c, residue_tCcCol); }); - copy_if(tCpCol_flt, tCgCol_flt, tCrCol_flt); - } - - constexpr int FrgSize = size(tCrCol_flt); - using FrgInput = Array; - using FrgCompute = Array; - using ConvertInput = NumericArrayConverter; - - Tensor tCrCol_input_frg = recast(coalesce(tCrCol_flt)); - Tensor tCrCol_compute_frg = recast(filter(tCrCol)); - ConvertInput convert_input{}; - - tCrCol_compute_frg(_0{}) = convert_input(tCrCol_input_frg(_0{})); - } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_col; - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); - } - - return frg_col; - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - auto layout_M = [&] () CUTLASS_LAMBDA_FUNC_INLINE { - auto shape_M = get<0>(args.problem_shape_mnkl); - if constexpr (IsDynamicBroadcast) { - auto stride_M = repeat_like(shape_M, int(0)); - if (get<0>(params.dCol) == bool(1)) { - stride_M = transform_leaf(compact_major(shape_M), - [] (auto const& stride) { return static_cast(stride); } - ); - } - return make_layout(shape_M, stride_M); - } - else { - return make_layout(shape_M); - } - }(); - - auto layout_N = make_layout(N, repeat_like(N, _0{})); - auto layout_L = make_layout(L, get<2>(params.dCol)); - ElementInput const* ptr_col = nullptr; - if constexpr(IsArrayOfPointers) { - if (!(EnableNullptr && params.ptr_col == nullptr)) { - ptr_col = params.ptr_col[l]; - } - } else { - ptr_col = params.ptr_col; - } - Tensor mCol = make_tensor(make_gmem_ptr(ptr_col), make_layout(layout_M,layout_N,layout_L)); - Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - Tensor mCol_static = make_tensor(make_gmem_ptr(ptr_col), make_layout(make_layout(M),layout_N,layout_L)); - Tensor tCgCol_static = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - return ConsumerStoreCallbacks(tCgCol, tCrCol, args.tCcD, args.residue_tCcD, params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Batch matrix broadcast -// Only need to redefine this if we can multicast across cluster L -template < - int Stages, - class EpilogueTile, - class Element, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpS2R, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true // Fallback scalar broadcast for nullptr params -> -using Sm90MatrixBroadcast - = Sm90AuxLoad; - -namespace detail { - -template -struct IsScalarBroadcast { - static constexpr bool value = false; -}; - -template -struct IsScalarBroadcast(typename Operation::StrideMNL{})), Stride<_0,_0>>>> { - static constexpr bool value = true; -}; - -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp deleted file mode 100644 index 06ad8082e57cedf4d16aecdad8a995e838e1c93e..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ /dev/null @@ -1,1722 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" - -#include "cute/tensor.hpp" -#include "sm90_visitor_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Elementwise Store Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int Stages, - class EpilogueTile, - class Element, - FloatRoundStyle RoundStyle, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpR2S, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true // Noop on nullptr params -> -struct Sm90AuxStore { - using ElementAux = Element; - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - - constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); - // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) - using SmemShapeTma = decltype(make_shape( - max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), - max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); - using SmemLayoutTma = decltype(tile_to_shape( - SmemLayoutAtom{}, SmemShapeTma{}, - cute::conditional_t, Step<_1,_2>>{} )); - using SmemLayout = decltype(tile_to_shape( - SmemLayoutTma{}, - make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), - cute::conditional_t, Step<_1,_2,_3>>{} )); - - struct SharedStorage { - alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) - array_aligned smem_aux; - }; - - struct Arguments { - Element* ptr_aux = nullptr; - StrideMNL dAux = {}; - }; - - struct Params { - using TMA_Aux = decltype(make_tma_copy( - SM90_TMA_STORE{}, - make_tensor(static_cast(nullptr), repeat_like(StrideMNL{}, int32_t(0)), StrideMNL{}), - SmemLayoutTma{})); - TMA_Aux tma_store_aux; - bool is_nullptr = false; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - - bool is_nullptr = false; - if constexpr (EnableNullptr) { - is_nullptr = args.ptr_aux == nullptr; - } - - typename Params::TMA_Aux tma_store_aux; - if (not is_nullptr) { - Tensor tensor_aux = make_tensor(args.ptr_aux, make_layout(make_shape(M,N,L), args.dAux)); - tma_store_aux = make_tma_copy(SM90_TMA_STORE{}, tensor_aux, SmemLayoutTma{}); - } - - return {tma_store_aux, is_nullptr}; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90AuxStore() { } - - CUTLASS_HOST_DEVICE - Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms), - smem_aux(const_cast(shared_storage.smem_aux.data())) { } - - Params const* params_ptr; - Element* smem_aux; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template < - class RTensor, - class TiledR2S, - class STensorR2S, - class STensorS2G, - class GTensorS2G - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tC_rAux, - TiledR2S tiled_r2s, - STensorR2S&& tRS_sAux, - STensorS2G&& bSG_sAux, - GTensorS2G&& bSG_gAux, - Params const* params_ptr) - : tiled_r2s(tiled_r2s), - tC_rAux(cute::forward(tC_rAux)), - tRS_sAux(cute::forward(tRS_sAux)), - bSG_sAux(cute::forward(bSG_sAux)), - bSG_gAux(cute::forward(bSG_gAux)), - params_ptr(params_ptr) {} - - TiledR2S tiled_r2s; - RTensor tC_rAux; // (CPY,CPY_M,CPY_N) - STensorR2S tRS_sAux; // (R2S,R2S_M,R2S_N,PIPE) - STensorS2G bSG_sAux; // (S2G,S2G_M,S2G_N,PIPE) - GTensorS2G bSG_gAux; // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - Params const* params_ptr; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { - using ConvertInput = NumericArrayConverter; - ConvertInput convert_input{}; - - Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) - tC_rAux_frg(epi_v) = convert_input(frg_input); - - return frg_input; - } - - CUTLASS_DEVICE void - postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { - if constexpr (EnableNullptr) { - if (params_ptr->is_nullptr) { - return; - } - } - - using RLayoutR2S = decltype(cute::layout(TiledR2S{}.get_slice(0).retile_S(RTensor{}))); - Tensor tRS_rAux = make_tensor(tC_rAux.data(), RLayoutR2S{}); // (R2S,R2S_M,R2S_N) - - if (issue_smem_store) { - int store_pipe_index = store_iteration % Stages; - copy(tiled_r2s, tRS_rAux, tRS_sAux(_,_,_,store_pipe_index)); - } - } - - CUTLASS_DEVICE void - tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { - if constexpr (EnableNullptr) { - if (params_ptr->is_nullptr) { - return; - } - } - - if (issue_tma_store) { - // Issue the TMA store - int store_pipe_index = store_iteration % Stages; - copy(params_ptr->tma_store_aux, bSG_sAux(_,_,_,store_pipe_index), bSG_gAux(_,_,_,epi_m,epi_n)); - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mAux = params_ptr->tma_store_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - - Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gAux, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) - - Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) - Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - auto tiled_r2s = conditional_return( - make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), - make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) - ); - auto tRS_sAux = tiled_r2s.get_slice(args.thread_idx).partition_D(sAux_epi); // (R2S,R2S_M,R2S_N,PIPE) - - ThrCopy thrblk_s2g = params_ptr->tma_store_aux.get_slice(_0{}); - Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) - Tensor bSG_gAux = thrblk_s2g.partition_D(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) - - return ConsumerStoreCallbacks( - cute::move(tC_rAux), - tiled_r2s, - cute::move(tRS_sAux), - cute::move(bSG_sAux), - cute::move(bSG_gAux), - params_ptr); - } -}; - -template < - class Element, - class EpilogueTile, // Unused - FloatRoundStyle RoundStyle, - class LayoutOrStrideMNL, - class SmemLayoutAtom, // Unused - class CopyOpR2S, // Unused - int Alignment, - bool EnableNullptr -> -struct Sm90AuxStore< - 0, EpilogueTile, Element, RoundStyle, LayoutOrStrideMNL, - SmemLayoutAtom, CopyOpR2S, Alignment, EnableNullptr -> { - using ElementAux = Element; - using StrideMNL = cutlass::gemm::TagToStrideC_t; - - struct SharedStorage { }; - - struct Arguments { - Element* ptr_aux = nullptr; - StrideMNL dAux = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return cutlass::Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Sm90AuxStore() { } - - CUTLASS_HOST_DEVICE - Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) - : params_ptr(¶ms) { } - - Params const* params_ptr; - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template< - class GTensorR2G, - class RTensor, - class CTensorR2G, - class ProblemShapeMNL - > - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - GTensorR2G&& tC_gAux, - RTensor&& tC_rAux, - CTensorR2G&& tC_cAux, - ProblemShapeMNL problem_shape_mnl, - Params const* params_ptr) - : tC_gAux(cute::forward(tC_gAux)), - tC_rAux(cute::forward(tC_rAux)), - tC_cAux(cute::forward(tC_cAux)), - problem_shape_mnl(problem_shape_mnl), - params_ptr(params_ptr) {} - - GTensorR2G tC_gAux; - RTensor tC_rAux; - CTensorR2G tC_cAux; - ProblemShapeMNL problem_shape_mnl; - Params const* params_ptr; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { - using ConvertInput = NumericArrayConverter; - ConvertInput convert_input{}; - - Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); - tC_rAux_frg(epi_v) = convert_input(frg_input); - - return frg_input; - } - - CUTLASS_DEVICE void - end_loop(int epi_m, int epi_n) { - if constexpr (EnableNullptr) { - if (params_ptr->ptr_aux == nullptr) { - return; - } - } - - constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; - constexpr int V = cute::min(Alignment, size(MCL)); - - Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); - Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); - - Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int{}))); - Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); }); - - copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec); - } - }; - - template < - bool ReferenceSrc, - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - - auto problem_shape_mnl = make_shape(M,N,L); - - // Gmem Tensor - Tensor mAux = make_tensor( - make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux - ); - Tensor tC_gAux = sm90_partition_for_epilogue( - mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - // Register Tensor - Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); - - // Predication support - Tensor coordAux = make_identity_tensor(shape(mAux)); - Tensor tC_cAux = sm90_partition_for_epilogue( - coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - - return ConsumerStoreCallbacks( - cute::move(tC_gAux), - cute::move(tC_rAux), - cute::move(tC_cAux), - problem_shape_mnl, - params_ptr - ); - - } - -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Reduction Store Operations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Scalar reduction -template < - template class RegReduceFn, - template class GmemReduceFn, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class StrideMNL = Stride<_0,_0,_0>, - bool EnableNullptr = true // Noop on nullptr params -> -struct Sm90ScalarReduction { -private: - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); - static constexpr bool IsAtomic = is_atomic>::value; - static_assert(IsAtomic, "non-atomic scalar reduction not supported yet"); - -public: - struct SharedStorage { }; - - struct Arguments { - ElementOutput* ptr_scalar = nullptr; - ElementCompute reduction_identity = ElementCompute(0); - StrideMNL dScalar = {}; - }; - - using Params = Arguments; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - #if !defined(CUTLASS_SKIP_REDUCTION_INIT) - if constexpr (IsAtomic) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar); - if (args.ptr_scalar != nullptr) { - return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter); - } - } - #endif - - return cutlass::Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_HOST_DEVICE - Sm90ScalarReduction() { } - - CUTLASS_HOST_DEVICE - Sm90ScalarReduction(Params const& params, SharedStorage const& shared_storage) - : params(params) { } - - Params const params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks( - int l_coord, - CTensor tCcScalar, - ThrResidue residue_tCcScalar, - Params const& params) - : scalar(params.reduction_identity), - l_coord(l_coord), - tCcScalar(tCcScalar), - residue_tCcScalar(residue_tCcScalar), - params(params) {} - - ElementCompute scalar; - int l_coord; - CTensor tCcScalar; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ThrResidue residue_tCcScalar; - Params params; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { - if constexpr (EnableNullptr) { - if (params.ptr_scalar == nullptr) { - return frg_input; - } - } - - using ConvertInput = NumericArrayConverter; - using ReduceInput = RegReduceFn; - ConvertInput convert_input{}; - ReduceInput reduce_input{}; - - Array frg_I = convert_input(frg_input); - Tensor tCcScalar_mn = tCcScalar(_,_,_,epi_m,epi_n); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - if (elem_less(tCcScalar_mn(epi_v * FragmentSize + i), residue_tCcScalar)) { - scalar = reduce_input(scalar, frg_I[i]); - } - } - - return frg_input; - } - - CUTLASS_DEVICE void - end() { - if constexpr (EnableNullptr) { - if (params.ptr_scalar == nullptr) { - return; - } - } - - using ConvertI = NumericConverter; - using ReduceInput = GmemReduceFn; - - ConvertI convert_I{}; - ReduceInput reduce_input{}; - - ElementOutput* ptr_scalar = params.ptr_scalar + l_coord * get<2>(params.dScalar); - reduce_input(ptr_scalar, convert_I(scalar)); - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks( - get<3>(args.tile_coord_mnkl), args.tCcD, args.residue_tCcD, params); - } - -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Row vector reduction -template < - template class RegReduceFn, - template class ShuffleReduceFn, - template class GmemReduceFn, - int Stages, - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class StrideMNL = Stride<_0,_1,_0>, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true, // Noop on nullptr params - // If this is false, ptr_row is assumed to point to a compact n-major (ceil_div(M,CTA_M), round_nearest(N,CTA_N), L) - // tensor of ElementCompute. It is the user's responsibility to reduce this to a (N, L) tensor of ElementOutput - bool FinalReduction = true, - // False means skip OOB predication if OOB inputs are known to be the reduction identity - bool VisitCheckOOB = true, - // Indicate the parameter order when calling RegReduceFn - // Seq length equals the number of RegReduceFn parameters - // No.0 represents tCrRow; No.1 and subsequent numbers sequentially represent frg_inputs in `visit` - class RegReduceSeq = cute::seq<0, 1> -> -struct Sm90RowReduction { -private: - static_assert(Stages == 0, "Smem usage not supported yet"); - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); - static constexpr bool IsAtomic = is_atomic>::value; - static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); - -public: - struct SharedStorage { }; - - struct Arguments { - void* ptr_row = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* - ElementCompute reduction_identity = ElementCompute(0); - StrideMNL dRow = {}; - }; - - struct Params { - void* ptr_row = nullptr; - ElementCompute reduction_identity = ElementCompute(0); - StrideMNL dRow = {}; - ElementCompute* reduction_buffer = nullptr; - int* tile_counters = nullptr; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - ElementCompute* reduction_buffer; - int* tile_counters = nullptr; - if constexpr (IsAtomic) { - reduction_buffer = nullptr; - } - else if constexpr (FinalReduction) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); - - reduction_buffer = reinterpret_cast(workspace); - tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); - } - else { - reduction_buffer = reinterpret_cast(args.ptr_row); - } - - return { - args.ptr_row, - args.reduction_identity, - args.dRow, - reduction_buffer, - tile_counters - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - if constexpr (IsAtomic || not FinalReduction) { - return 0; - } - - size_t workspace_size = 0; - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - // Increment by size of reduction buffer - workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); - // Align and increment by size of tile counters - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - workspace_size += cute::ceil_div(size<>(N), tile_N) * sizeof(int); - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - if constexpr (IsAtomic) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow); - if (args.ptr_row != nullptr) { - return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter); - } - return Status::kSuccess; - } - else if constexpr (FinalReduction) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); - - int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); - size_t tile_counters_size = cute::ceil_div(size<>(N), tile_N) * sizeof(int); - return zero_workspace(tile_counters, tile_counters_size, stream, cuda_adapter); - } - else { - return Status::kSuccess; - } - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_HOST_DEVICE - Sm90RowReduction() { } - - CUTLASS_HOST_DEVICE - Sm90RowReduction(Params const& params, SharedStorage const& shared_storage) - : params(params) { } - - Params params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) - : args_tuple(cute::forward(args_tuple)), - params(params) {} - - ArgsTuple args_tuple; - Params const& params; - bool do_final_reduction = false; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const&... frg_inputs) { - if constexpr (EnableNullptr) { - if (params.ptr_row == nullptr) { - return cute::get<0>(cute::make_tuple(frg_inputs...)); - } - } - - auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; - Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); - Tensor tCcRow_mn = tCcRow(_,_,_,epi_m,epi_n); - - if constexpr (VisitCheckOOB) { - using ReduceInput = RegReduceFn; - ReduceInput reduce_input{}; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - if (elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_tCcRow)) { - ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); - tCrRow_vmn = transform_apply(cute::make_tuple(frg_inputs...), - [&] (auto&& frg_input) { - return ElementCompute(frg_input[i]); - }, - [&] (auto&&... cvt_frg_inputs) { - auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn, cvt_frg_inputs...); - return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); - }); - } - } - } - else { - constexpr int RegFragSize = cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute))); - using ReduceInput = RegReduceFn>; - ReduceInput reduce_input{}; - Tensor tCrRow_mn_frg = recast>(tCrRow_mn); - - constexpr int RegFragArraySize = FragmentSize / RegFragSize; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < RegFragArraySize; ++i) { - Array& tCrRow_vmn_frg = tCrRow_mn_frg(epi_v * RegFragArraySize + i); - tCrRow_vmn_frg = transform_apply(cute::make_tuple(frg_inputs...), - [&] (auto&& frg_input) { - using ElementInput = typename cute::remove_cvref_t::Element; - using ConvertInput = NumericArrayConverter; - using RegFragArr = Array, RegFragArraySize>; - ConvertInput convert_input{}; - return convert_input(reinterpret_cast(frg_input)[i]); - }, - [&] (auto&&... cvt_frg_inputs) { - auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn_frg, cvt_frg_inputs...); - return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); - }); - } - } - return cute::get<0>(cute::make_tuple(frg_inputs...)); - } - - template - CUTLASS_DEVICE void - reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - if (not is_last_iteration) { - return; - } - - auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; - auto [m, n, k, l] = tile_coord_mnkl; - constexpr bool ReferenceSrc = decltype(ref_src)::value; - if constexpr (EnableNullptr) { - if (params.ptr_row == nullptr) { - return; - } - } - - // fully OOB CTA in partially OOB cluster - if (not elem_less(cRow(_0{},_0{}), residue_cRow)) { - return; - } - - int lane_m = get<0>(lane_mn); - [[maybe_unused]] bool is_reduced_lane = lane_m == 0; - - // - // 1. Warp shuffle reduction - // - using FragmentShuffle = Array; - Tensor tCrRow_frg = recast(filter(tCrRow)); - using ReduceShuffle = ShuffleReduceFn; - ReduceShuffle reduce_shuffle{}; - - auto FrgSizePerLaneM = size(tCrRow_frg) / size<0>(lane_layout_MN); - constexpr bool SwapShuffle = FrgSizePerLaneM > 0; - - // - // Swap Shuffle - // - // The normal way to reduction among threads: - // use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads. - // After each step of reduction, a half of threads won't work in the following steps. - // That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case). - // - // To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors, - // we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads. - // After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step. - // We can recursively do this until the problem size is 1. - // - if constexpr (SwapShuffle) { // for a NxN matrix to be reduced among N threads as a 1XN vectors - Tensor tCrRow_frg_ = logical_divide(tCrRow_frg, FrgSizePerLaneM); // (FrgSizePerLaneM, M) - CUTLASS_PRAGMA_UNROLL - for (int m = size<1>(tCrRow_frg_) / 2; m > 0; m /= 2) { - CUTLASS_PRAGMA_UNROLL - for (int r = 0; r < m; ++r) { - auto frg_A = tCrRow_frg_(_,r); - auto frg_B = tCrRow_frg_(_,r + m); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < size(frg_A); ++v) { - // Step1: swap - if (not (lane_m & m)) { // the first half of threads swap fragments from the first half of data to the second - cutlass::swap(frg_A(v), frg_B(v)); - } - - // Step2: shuffle - uint64_t frg_shfl = reinterpret_cast(frg_A(v)); - // each half of threads get a half of data from the other half of threads - frg_shfl = __shfl_xor_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(m, _0{})); - - // Step3: reduction - frg_A(v) = reduce_shuffle(frg_B(v), reinterpret_cast(frg_shfl)); - } - } - } - } - else { - CUTLASS_PRAGMA_UNROLL - for (int reduction_rows = size<0>(lane_layout_MN) / 2; reduction_rows > 0; reduction_rows /= 2) { - CUTLASS_PRAGMA_UNROLL - for (int frg_idx = 0; frg_idx < size(tCrRow_frg); ++frg_idx) { - uint64_t frg_shfl = reinterpret_cast(tCrRow_frg(frg_idx)); - frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(reduction_rows, _0{})); - tCrRow_frg(frg_idx) = reduce_shuffle(tCrRow_frg(frg_idx), reinterpret_cast(frg_shfl)); - } - } - } - - // - // 2. Atomic reduction - // - if constexpr (IsAtomic) { - // Filter so we don't issue redunant copies over stride-0 modes - Tensor tCrRow_flt = filter_zeros(tCrRow); - Tensor tCcRow_flt = make_tensor(tCcRow.data(), make_layout(tCrRow_flt.shape(), tCcRow.stride())); - auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); - - Tensor tCgRow = sm90_partition_for_epilogue(gRow_l(_,_,l), epi_tile, tiled_copy, thread_idx); - Tensor tCgRow_flt = filter_zeros(tCgRow); - // NOTE: atomic reduction is performed in the output type - using ConvertOutput = NumericConverter; - using ReduceOutput = GmemReduceFn; - ConvertOutput convert_output{}; - ReduceOutput reduce_output{}; - - if constexpr (SwapShuffle) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FltFrgSizePerLaneM; ++i) { - int idx = lane_m * FltFrgSizePerLaneM + i; - // Only care about OOB for N mode - if (get<1>(tCcRow_flt(idx)) < get<1>(residue_tCcRow)) { - reduce_output(&tCgRow_flt(idx), convert_output(tCrRow_flt(i))); - } - } - } - else { - if (is_reduced_lane) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrRow_flt); ++i) { - if (elem_less(tCcRow_flt(i), residue_tCcRow)) { - reduce_output(&tCgRow_flt(i), convert_output(tCrRow_flt(i))); - } - } - } - } - sync_fn(); - } - - // - // 2. One warp in M, skip threadblock smem reduction - // - else if constexpr (decltype(size<0>(warp_layout_MN))::value <= 1) { - // Dump warp reduction to gmem workspace - using ElementGmem = cute::conditional_t; - Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_ml(_,_,m,l), epi_tile, tiled_copy, thread_idx); - - if constexpr (SwapShuffle) { - Tensor tCrRow_flt = filter(tCrRow); - Tensor tCgBuf_flt = recast(filter(tCgBuf)); - auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); - Tensor tCgBuf_flt_ = logical_divide(tCgBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) - Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) - copy_aligned(tCrRow_flt_(_,_0{}), tCgBuf_flt_(_,lane_m)); - } - else { - if (is_reduced_lane) { - copy_aligned(tCrRow, recast(tCgBuf)); - } - } - sync_fn(); - } - - // - // 2. Multiple warps in M, do threadblock smem reduction - // - else { - Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); - static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= - decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), - "smem reduction buffer not large enough, use a larger epilogue tile"); - sync_fn(); - - // Dump warp reduction to smem workspace - Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<0>(warp_mn)), epi_tile, tiled_copy, thread_idx); - - if constexpr (SwapShuffle) { - Tensor tCrRow_flt = filter(tCrRow); - Tensor tCsBuf_flt = filter(tCsBuf); - auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); - Tensor tCsBuf_flt_ = logical_divide(tCsBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) - Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) - copy_aligned(tCrRow_flt_(_,_0{}), tCsBuf_flt_(_,lane_m)); - } - else { - if (is_reduced_lane) { - copy_aligned(tCrRow, tCsBuf); - } - } - sync_fn(); - - constexpr int SmemFragSize = cute::max(size_t{1}, sizeof(uint32_t) / sizeof(ElementCompute)); - using FragmentSmem = Array; - using VectorSmem = uint_bit_t>; - using ReduceSmem = GmemReduceFn; - ReduceSmem reduce_smem{}; - - Tensor sBuf_frg = recast(filter_zeros(sBuf)); - Tensor sBuf_vec = recast(filter_zeros(sBuf)); - constexpr int FragsPerRow = decltype(size<1>(sBuf_frg))::value; - - constexpr int RowNum = decltype(size<0>(warp_layout_MN))::value; - using FragmentSmemArray = Array; - - // Do the threadblock smem reduction - using VectorGmem = cute::conditional_t; - Tensor gBuf_vec = recast(filter(gBuf_ml(_,_,m,l))); - CUTLASS_PRAGMA_UNROLL - for (int frg_idx = thread_idx; frg_idx < FragsPerRow; frg_idx += size(tiled_copy)) { - FragmentSmemArray frg_smem; - - CUTLASS_PRAGMA_UNROLL - for (int reduction_rows = 0; reduction_rows < RowNum; ++reduction_rows) { - int FragsCurrRows = reduction_rows * FragsPerRow; - frg_smem[reduction_rows] = sBuf_frg(FragsCurrRows + frg_idx); - } - - CUTLASS_PRAGMA_UNROLL - for (int reduction_rows = RowNum / 2; reduction_rows > 0; reduction_rows /= 2) { - CUTLASS_PRAGMA_UNROLL - for (int row_idx = 0; row_idx < reduction_rows; ++row_idx) { - frg_smem[row_idx] = reduce_smem(frg_smem[row_idx], frg_smem[row_idx + reduction_rows]); - } - } - gBuf_vec(frg_idx) = reinterpret_cast(frg_smem[0]); - } - sync_fn(); - } - - // - // 3. Increment atomic counters to signal final gmem reduction - // - if constexpr (not IsAtomic && FinalReduction) { - // Ensure gmem writes are visible to other threads before incrementing counter - __threadfence(); - sync_fn(); - // Collective thread 0 increments atomic tile counter and copies value to smem - int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); - if (thread_idx == 0) { - *prev_tile_count = atomicAdd(¶ms.tile_counters[n], 1); - } - sync_fn(); - // Broadcast tile count to other threads in CTA and determine final reduction status - do_final_reduction = *prev_tile_count == size<2>(gBuf_ml) * size<3>(gBuf_ml) - 1; - sync_fn(); - } - } - - CUTLASS_DEVICE void - end() { - // - // 4. Do final gmem reduction if necessary - // - if constexpr (not IsAtomic && FinalReduction) { - if (not do_final_reduction) { - return; - } - - auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; - - using ReduceOutput = GmemReduceFn; - using ConvertOutput = NumericConverter; - ReduceOutput reduce_output{}; - ConvertOutput convert_output{}; - - // Reduction over batches - if (size<2>(stride(gRow_l)) == 0) { - CUTLASS_PRAGMA_NO_UNROLL - for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { - Tensor tRgBuf_ml = gBuf_ml(_0{},n,_,_); - ElementCompute output = tRgBuf_ml(_0{}); - CUTLASS_PRAGMA_NO_UNROLL - for (int ml = 1; ml < size(tRgBuf_ml); ++ml) { - output = reduce_output(output, tRgBuf_ml(ml)); - } - if (elem_less(cRow(_0{},n), residue_cRow)) { - gRow_l(_0{},n,_0{}) = convert_output(output); - } - } - } - // No reduction over batches - else { - CUTLASS_PRAGMA_NO_UNROLL - for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { - bool do_store = elem_less(cRow(_0{},n), residue_cRow); - CUTLASS_PRAGMA_NO_UNROLL - for (int l = 0; l < size<3>(gBuf_ml); ++l) { - Tensor tRgBuf_m = gBuf_ml(_0{},n,_,l); - ElementCompute output = tRgBuf_m(_0{}); - CUTLASS_PRAGMA_NO_UNROLL - for (int m = 1; m < size(tRgBuf_m); ++m) { - output = reduce_output(output, tRgBuf_m(m)); - } - if (do_store) { - gRow_l(_0{},n,l) = convert_output(output); - } - } - } - } - - } - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - Layout ref_layout_MN = [&] () { - auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); - if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } - else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } - }(); // tile_mn -> tv_idx - - // Get the MN layout + coord of lanes to determine shuffle reduction iterations - using _W = Int; - Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx - Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx - Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx - Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn - int lane_idx = canonical_lane_idx(); - auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); - - // Get the MN layout + coord of warps to determine smem reduction iterations - Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx - Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx - Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx - Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn - - int warp_idx = args.thread_idx / NumThreadsPerWarp; - auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN)); - - // Partition output gmem and register tensors - auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk; - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); // (M,N,L) - Tensor gRow_l = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L) - Tensor tCgRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gRow_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrRow = make_tensor_like(tCgRow); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - fill(tCrRow, params.reduction_identity); - - // Partition gmem+smem reduction buffer tensors - Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_0{}, _1{})); - auto block_shape = ceil_div(make_shape(M,N,L), shape(gBuf_layout)); // (M_CNT, N_CNT, L_CNT) - - // Let the M_CNT (the num of partial reduction results) become the outer mode - Layout block_layout = make_layout(block_shape, make_stride(get<1>(block_shape), _1{}, get<0>(block_shape) * get<1>(block_shape))); - Layout mBuf_layout = blocked_product(gBuf_layout, block_layout); - Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L) - Tensor gBuf_ml = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(_,n,_)); // (CTA_M,CTA_N,REST_M,L) - Layout sBuf_layout = blocked_product(gBuf_layout, // (CTA_M,CTA_N,WARPS_M) - make_layout(make_shape(_1{},_1{},size<0>(warp_layout_MN)))); - - auto args_tuple = make_tuple( - bool_constant{}, cute::move(tCrRow), args.tCcD, gRow_l, args.cD, gBuf_ml, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - args.tile_coord_mnkl, args.residue_cD, args.residue_tCcD, args.epi_tile, args.tiled_copy, args.thread_idx); - return ConsumerStoreCallbacks(cute::move(args_tuple), params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Col vector reduction -template < - template class RegReduceFn, - template class ShuffleReduceFn, - template class GmemReduceFn, - int Stages, - class CtaTileShapeMNK, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class StrideMNL = Stride<_1,_0,_0>, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true, // Noop on nullptr params - // If this is false, ptr_col is assumed to point to a compact m-major (round_nearest(M,CTA_M), ceil_div(N,CTA_N), L) - // tensor of ElementCompute. It is the user's responsibility to reduce this to a (M, L) tensor of ElementOutput - bool FinalReduction = true, - // False means skip OOB predication if OOB inputs are known to be the reduction identity - bool VisitCheckOOB = true -> -struct Sm90ColReduction { -private: - static_assert(Stages == 0, "Smem usage not supported yet"); - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static - static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{}); - static constexpr bool IsAtomic = is_atomic>::value; - static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); - -public: - struct SharedStorage { }; - - struct Arguments { - void* ptr_col = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* - ElementCompute reduction_identity = ElementCompute(0); - StrideMNL dCol = {}; - }; - - struct Params { - void* ptr_col = nullptr; - ElementCompute reduction_identity = ElementCompute(0); - StrideMNL dCol = {}; - ElementCompute* reduction_buffer = nullptr; - int* tile_counters = nullptr; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - ElementCompute* reduction_buffer; - int* tile_counters = nullptr; - if constexpr (IsAtomic) { - reduction_buffer = nullptr; - } - else if constexpr (FinalReduction) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); - - reduction_buffer = reinterpret_cast(workspace); - tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); - } - else { - reduction_buffer = reinterpret_cast(args.ptr_col); - } - - return { - args.ptr_col, - args.reduction_identity, - args.dCol, - reduction_buffer, - tile_counters - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return true; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - if constexpr (IsAtomic || not FinalReduction) { - return 0; - } - - size_t workspace_size = 0; - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - - // Increment by size of reduction buffer - workspace_size += product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); - // Align and increment by size of tile counters - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - workspace_size += cute::ceil_div(M, tile_M) * sizeof(int); - - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - if constexpr (IsAtomic) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol); - if (args.ptr_col != nullptr) { - return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter); - } - return Status::kSuccess; - } - else if constexpr (FinalReduction) { - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); - - int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); - size_t tile_counters_size = cute::ceil_div(M, tile_M) * sizeof(int); - return zero_workspace(tile_counters, tile_counters_size, stream, cuda_adapter); - } - else { - return Status::kSuccess; - } - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_HOST_DEVICE - Sm90ColReduction() { } - - CUTLASS_HOST_DEVICE - Sm90ColReduction(Params const& params, SharedStorage const& shared_storage) - : params(params) { } - - Params params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) - : args_tuple(cute::forward(args_tuple)), - params(params) {} - - ArgsTuple args_tuple; - Params const& params; - bool do_final_reduction = false; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { - if constexpr (EnableNullptr) { - if (params.ptr_col == nullptr) { - return frg_input; - } - } - - auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); - Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); - - using ConvertInput = NumericArrayConverter; - using ReduceInput = RegReduceFn; - ConvertInput convert_input{}; - ReduceInput reduce_input{}; - - Array frg_I = convert_input(frg_input); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - if (!VisitCheckOOB || elem_less(tCcCol_mn(epi_v * FragmentSize + i), residue_tCcCol)) { - ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); - tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); - } - } - - return frg_input; - } - - template - CUTLASS_DEVICE void - reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - if (not is_last_iteration) { - return; - } - - auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; - auto [m, n, k, l] = tile_coord_mnkl; - constexpr bool ReferenceSrc = decltype(ref_src)::value; - - // Runtime nullptr is noop - if constexpr (EnableNullptr) { - if (params.ptr_col == nullptr) { - return; - } - } - - // fully OOB CTA in partially OOB cluster - if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { - return; - } - - // - // 1. Warp shuffle reduction - // - using FragmentShuffle = Array; - using ReduceShuffle = ShuffleReduceFn; - ReduceShuffle reduce_shuffle{}; - Tensor tCrCol_frg = recast(filter(tCrCol)); - CUTLASS_PRAGMA_UNROLL - for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { - CUTLASS_PRAGMA_UNROLL - for (int frg_idx = 0; frg_idx < size(tCrCol_frg); ++frg_idx) { - uint64_t frg_shfl = reinterpret_cast(tCrCol_frg(frg_idx)); - frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(_0{},reduction_cols)); - tCrCol_frg(frg_idx) = reduce_shuffle(tCrCol_frg(frg_idx), reinterpret_cast(frg_shfl)); - } - } - bool is_reduced_lane = get<1>(lane_mn) == 0; - - // - // 2. Atomic reduction - // - if constexpr (IsAtomic) { - // Filter so we don't issue redunant copies over stride-0 modes - Tensor tCrCol_flt = filter_zeros(tCrCol); - Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCrCol_flt.shape(), tCcCol.stride())); - - Tensor tCgCol = sm90_partition_for_epilogue(gCol_l(_,_,l), epi_tile, tiled_copy, thread_idx); - Tensor tCgCol_flt = filter_zeros(tCgCol); - - // NOTE: atomic reduction is performed in the output type - using ConvertOutput = NumericConverter; - using ReduceOutput = GmemReduceFn; - ConvertOutput convert_output{}; - ReduceOutput reduce_output{}; - - if (is_reduced_lane) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrCol_flt); ++i) { - if (elem_less(tCcCol_flt(i), residue_tCcCol)) { - reduce_output(&tCgCol_flt(i), convert_output(tCrCol_flt(i))); - } - } - } - sync_fn(); - } - - // - // 2. One warp in N, skip threadblock smem reduction - // - else if constexpr (decltype(size<1>(warp_layout_MN))::value <= 1) { - // Dump warp reduction to gmem workspace - using ElementGmem = cute::conditional_t; - Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_nl(_,_,n,l), epi_tile, tiled_copy, thread_idx); - if (is_reduced_lane) { - copy_aligned(tCrCol, recast(tCgBuf)); - } - sync_fn(); - } - - // - // 2. Multiple warps in N, do threadblock smem reduction - // - else { - Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); - static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= - decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), - "smem reduction buffer not large enough, use a larger epilogue tile"); - sync_fn(); - - // Dump warp reduction to smem workspace - Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<1>(warp_mn)), epi_tile, tiled_copy, thread_idx); - if (is_reduced_lane) { - copy_aligned(tCrCol, tCsBuf); - } - sync_fn(); - - constexpr int SmemFragSize = cute::max(size_t{1}, sizeof(uint32_t) / sizeof(ElementCompute)); - using FragmentSmem = Array; - using VectorSmem = uint_bit_t>; - using ReduceSmem = GmemReduceFn; - ReduceSmem reduce_smem{}; - - Tensor sBuf_frg = recast(filter_zeros(sBuf)); - Tensor sBuf_vec = recast(filter_zeros(sBuf)); - constexpr int FragsPerCol = decltype(size<0>(sBuf_frg))::value; - - // Do the threadblock smem reduction - CUTLASS_PRAGMA_UNROLL - for (int reduction_cols = size<1>(warp_layout_MN) / 2; reduction_cols > 1; reduction_cols /= 2) { - int FragsPerReduction = reduction_cols * FragsPerCol; - CUTLASS_PRAGMA_NO_UNROLL - for (int frg_idx = thread_idx; frg_idx < FragsPerReduction; frg_idx += size(tiled_copy)) { - FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerReduction)); - sBuf_vec(frg_idx) = reinterpret_cast(frg_smem); - } - sync_fn(); - } - - // Do final smem reduction and dump to gmem workspace - using VectorGmem = cute::conditional_t; - Tensor gBuf_vec = recast(filter(gBuf_nl(_,_,n,l))); - CUTLASS_PRAGMA_NO_UNROLL - for (int frg_idx = thread_idx; frg_idx < FragsPerCol; frg_idx += size(tiled_copy)) { - FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerCol)); - gBuf_vec(frg_idx) = reinterpret_cast(frg_smem); - } - sync_fn(); - } - - // - // 3. Increment atomic counters to signal final gmem reduction - // - if constexpr (not IsAtomic && FinalReduction) { - // Ensure gmem writes are visible to other threads before incrementing counter - __threadfence(); - sync_fn(); - // Collective thread 0 increments atomic tile counter and copies value to smem - int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); - if (thread_idx == 0) { - *prev_tile_count = atomicAdd(¶ms.tile_counters[m], 1); - } - sync_fn(); - // Broadcast tile count to other threads in CTA and determine final reduction status - do_final_reduction = *prev_tile_count == size<2>(gBuf_nl) * size<3>(gBuf_nl) - 1; - sync_fn(); - } - } - - CUTLASS_DEVICE void - end() { - // - // 4. Do final gmem reduction if necessary - // - if constexpr (not IsAtomic && FinalReduction) { - if (not do_final_reduction) { - return; - } - - auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; - - using ReduceOutput = GmemReduceFn; - using ConvertOutput = NumericConverter; - ReduceOutput reduce_output{}; - ConvertOutput convert_output{}; - - // Reduction over batches - if (size<2>(stride(gCol_l)) == 0) { - CUTLASS_PRAGMA_NO_UNROLL - for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { - Tensor tRgBuf_nl = gBuf_nl(m,_0{},_,_); - ElementCompute output = tRgBuf_nl(_0{}); - CUTLASS_PRAGMA_NO_UNROLL - for (int nl = 1; nl < size(tRgBuf_nl); ++nl) { - output = reduce_output(output, tRgBuf_nl(nl)); - } - if (elem_less(cCol(m,_0{}), residue_cCol)) { - gCol_l(m,_0{},_0{}) = convert_output(output); - } - } - } - // No reduction over batches - else { - CUTLASS_PRAGMA_NO_UNROLL - for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { - bool do_store = elem_less(cCol(m,_0{}), residue_cCol); - CUTLASS_PRAGMA_NO_UNROLL - for (int l = 0; l < size<3>(gBuf_nl); ++l) { - Tensor tRgBuf_n = gBuf_nl(m,_0{},_,l); - ElementCompute output = tRgBuf_n(_0{}); - CUTLASS_PRAGMA_NO_UNROLL - for (int n = 1; n < size(tRgBuf_n); ++n) { - output = reduce_output(output, tRgBuf_n(n)); - } - if (do_store) { - gCol_l(m,_0{},l) = convert_output(output); - } - } - } - } - - } - } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - Layout ref_layout_MN = [&] () { - auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); - if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } - else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } - }(); // tile_mn -> tv_idx - - // Get the MN layout + coord of lanes to determine shuffle reduction iterations - using _W = Int; - Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx - Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx - Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx - Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn - int lane_idx = canonical_lane_idx(); - auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); - - // Get the MN layout + coord of warps to determine smem reduction iterations - Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx - Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx - Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx - Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn - int warp_idx = args.thread_idx / NumThreadsPerWarp; - auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN)); - - // Partition output gmem and register tensors - auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk; - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); // (M,N,L) - Tensor gCol_l = local_tile(mCol, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L) - Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gCol_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - fill(tCrCol, params.reduction_identity); - - // Partition gmem+smem reduction buffer tensors - Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_1{}, _0{})); - Layout mBuf_layout = blocked_product(gBuf_layout, make_layout(ceil_div(make_shape(M,N,L), shape(gBuf_layout)))); - Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L) - Tensor gBuf_nl = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(m,_,_)); // (CTA_M,CTA_N,REST_N,L) - Layout sBuf_layout = blocked_product(gBuf_layout,make_layout(make_shape(_1{},_1{},size<1>(warp_layout_MN)))); // (CTA_M,CTA_N,WARPS_N) - - auto args_tuple = make_tuple( - bool_constant{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - args.tile_coord_mnkl, args.residue_cD, args.residue_tCcD, args.epi_tile, args.tiled_copy, args.thread_idx); - return ConsumerStoreCallbacks(std::move(args_tuple), params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Batch matrix reduction -template < - int Stages, - class EpilogueTile, - class Element, - class StrideMNL, - class CopyOpR2S, - class SmemLayoutAtom, - int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true // Noop on nullptr params -> -struct Sm90MatrixReduction; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp deleted file mode 100644 index 93720f8d3d71f3f4759463b5d40e604313b7e3a4..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ /dev/null @@ -1,1149 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree operation base implementation to enable composable fusions - for the sm90 TMA warp-specialized (ws) epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" -#include "cutlass/detail/helper_macros.hpp" - -#include "cute/tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -using namespace cute; -using cute::tuple; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partitioning Helpers -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class CtaTileMN, - class EpilogueTile, - class TiledCopy -> -CUTLASS_HOST_DEVICE -constexpr auto -sm90_partition_for_epilogue( - CtaTileMN cT, // (CTA_M,CTA_N,...) - EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) - TiledCopy tiled_copy, - int thread_idx) { - ThrCopy thread_copy = tiled_copy.get_thread_slice(thread_idx); - Tensor cT_epi = flat_divide(cT, epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,...) - if constexpr (ReferenceSrc) { - return thread_copy.partition_S(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) - } - else { - return thread_copy.partition_D(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) - } -} - -template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class Engine, class LayoutMNL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy -> -CUTLASS_HOST_DEVICE -constexpr auto -sm90_partition_for_epilogue( - Tensor mT, // (M,N,L) - TileShapeMNK tile_shape_mnk, // (CTA_M,CTA_N,CTA_K) - TileCoordMNKL tile_coord_mnkl, // (m,n,k,l) - EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) - TiledCopy tiled_copy, - int thread_idx) { - auto [m, n, k, l] = tile_coord_mnkl; - auto coord_shape = - make_coord(m, n, l) - ; - Tensor cT = local_tile(mT, take<0,2>(tile_shape_mnk), coord_shape); // (CTA_M,CTA_N) - Tensor tCcT = - sm90_partition_for_epilogue(cT, epi_tile, tiled_copy, thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - - return tCcT; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Visitor Implementation -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// -// Producer load callbacks, called by the epilogue load warp. -// Operations usually only define this if TMA load is needed. Most operations will reuse this empy implementation -// Load callbacks are responsible for issuing corresponding mbarrier expect-tx ops for any TMA loads issued, but -// are not responsible for issuing the producer_commit barrier arrival, which is issued by the collective instead -// If this is non-empty, is_producer_load_needed must be true. -// -template -struct ProducerLoadCallbacksImpl { - // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables - CallbacksTuple callbacks_tuple; - - // Before entry of the subtile load loop - CUTLASS_DEVICE void - begin() { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin(); - } - ); - } - - // Entry of the subtile load loop. Aux loads usually performed here - // Upon entry the producer acquire of the current subtile lock has completed. - // Upon exit all TMA loads for this subtile must have been issued, with corresponding expect-tx operations - CUTLASS_DEVICE void - step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.step(full_mbarrier_ptr, epi_m, epi_n, load_iteration, issue_tma_load); - } - ); - } - - // Exit of the subtile load loop. - CUTLASS_DEVICE void - end() { - for_each(callbacks_tuple, - [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end(); - } - ); - } -}; - - -// -// Consumer store callbacks, called by the epilogue store warps. -// All operations must redefine this, with optional inheritance from this empty implementation. -// -template -struct ConsumerStoreCallbacksImpl { - // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables - CallbacksTuple callbacks_tuple; - - // Before entry of subtile store loop. Gmem broadcasts usually performed here. - CUTLASS_DEVICE void - begin() { - for_each(callbacks_tuple, - [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin(); - } - ); - } - - // Is a thread sync needed after begin(). Allows chaining async copies across multiple nodes - CUTLASS_DEVICE bool - begin_sync_needed() const { - return cute::apply(callbacks_tuple, - [] (auto const&... callbacks) { - return (false || ... || callbacks.begin_sync_needed()); - } - ); - } - - // Start of subtile store iteration - CUTLASS_DEVICE void - begin_loop(int epi_m, int epi_n) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.begin_loop(epi_m, epi_n); - } - ); - } - - // Before visit callback. Smem broadcasts usually performed here. - // Upon entry, all producer loads for this subtile are completed and visible. - CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.previsit(epi_m, epi_n, load_iteration, is_producer_load_needed); - } - ); - } - - // Perform the fused elementwise computation - template - CUTLASS_DEVICE auto // returns an Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const&... frg_inputs) // depends on the N-naryness of the op - = delete; // Must be implemented for each operation - - // After visit call. Smem reductions usually performed here - // reduction_buffer is an arbitrary smem tensor that can be used for workspace - // It is each nodes reponsibility to assert that this buffer is sufficiently sized - // and to ensure that this buffer is no longer needed upon callback exit - // i.e. results are synchronized and no longer in the reduction buffer - // - // visit_results is a rmem tensor that contains the results of visit() for an entire - // on the current epilogue subtile - template - CUTLASS_DEVICE void - reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration, visit_results); - } - ); - } - - // After reduce call, before smem async fence. Smem stores usually performed here. - // Upon exit, all smem stores for TMA must have been issued - CUTLASS_DEVICE void - postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.postreduce(epi_m, epi_n, store_iteration, issue_smem_store); - } - ); - } - - // After smem async fence, before TMA store commit. Aux stores usually performed here - // Upon exit, all TMA stores for this subtile must have been issued - // Because of the TMA store delay optimization, this entry point must ONLY be used for TMA stores - // other gmem stores can be placed in the reduce or postreduce entry points - CUTLASS_DEVICE void - tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.tma_store(epi_m, epi_n, store_iteration, issue_tma_store); - } - ); - } - - // End of subtile store iteration - CUTLASS_DEVICE void - end_loop(int epi_m, int epi_n) { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end_loop(epi_m, epi_n); - } - ); - } - - // Exit of subtile store loop. Gmem reductions usually performed here. - CUTLASS_DEVICE void - end() { - for_each(callbacks_tuple, - [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - callbacks.end(); - } - ); - } -}; - -template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class TiledMma, - class EpilogueTile -> -struct ProducerLoadArgs { - ProblemShapeMNKL problem_shape_mnkl; - TileShapeMNK tile_shape_mnk; - TileCoordMNKL tile_coord_mnkl; - TiledMma tiled_mma; - EpilogueTile epi_tile; - int thread_idx; - - CUTLASS_DEVICE - ProducerLoadArgs( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - TiledMma tiled_mma, - EpilogueTile epi_tile, - int thread_idx) - : problem_shape_mnkl(problem_shape_mnkl), - tile_shape_mnk(tile_shape_mnk), - tile_coord_mnkl(tile_coord_mnkl), - tiled_mma(tiled_mma), - epi_tile(epi_tile), - thread_idx(thread_idx) {} -}; - -template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class TiledMma, - class EpilogueTile, - class TiledCopy, - class CoordTensor, - class Residue, - class ThrCoordTensor, - class ThrResidue, - class ThrSrcTensor -> -struct ConsumerStoreArgs { - ProblemShapeMNKL problem_shape_mnkl; - TileShapeMNK tile_shape_mnk; - TileCoordMNKL tile_coord_mnkl; - TiledMma tiled_mma; - EpilogueTile epi_tile; - TiledCopy tiled_copy; - CoordTensor cD; - Residue residue_cD; - ThrCoordTensor tCcD; - ThrResidue residue_tCcD; - ThrSrcTensor & tCrC; - int thread_idx; - - CUTLASS_DEVICE - ConsumerStoreArgs( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - TiledMma tiled_mma, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - CoordTensor cD, - Residue residue_cD, - ThrCoordTensor tCcD, - ThrResidue residue_tCcD, - ThrSrcTensor & tCrC, - int thread_idx) - : problem_shape_mnkl(problem_shape_mnkl), - tile_shape_mnk(tile_shape_mnk), - tile_coord_mnkl(tile_coord_mnkl), - tiled_mma(tiled_mma), - epi_tile(epi_tile), - tiled_copy(tiled_copy), - cD(cD), - residue_cD(residue_cD), - tCcD(tCcD), - residue_tCcD(residue_tCcD), - tCrC(tCrC), - thread_idx(thread_idx) {} -}; - -template -struct Sm90VisitorImplBase { - // Shared memory allocation - using SharedStorage = tuple; - // Host side fusion arguments - using Arguments = tuple; - // Device side fusion params (Kernel-entry API) - using Params = tuple; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - uint8_t* op_workspace = reinterpret_cast(workspace); - return transform_apply(tuple{}, args, - [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { - using Op = cute::remove_cvref_t; - auto ret = Op::to_underlying_arguments(problem_shape, op_args, op_workspace); - if (op_workspace != nullptr) { - size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); - op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); - } - return ret; - }, - [] (auto&&... op_params) CUTLASS_LAMBDA_FUNC_INLINE { return cute::make_tuple(op_params...); } - ); - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return transform_apply(tuple{}, args, - [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { - using Op = cute::remove_cvref_t; - return Op::can_implement(problem_shape, op_args); - }, - [&] (auto&&... implementable) CUTLASS_LAMBDA_FUNC_INLINE { - return (true && ... && implementable); - } - ); - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return transform_apply(tuple{}, args, - [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { - using Op = cute::remove_cvref_t; - size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); - return round_nearest(op_workspace_size, MinWorkspaceAlignment); - }, - [&] (auto&&... op_workspace_size) CUTLASS_LAMBDA_FUNC_INLINE { - return (0 + ... + op_workspace_size); - } - ); - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* op_workspace = reinterpret_cast(workspace); - return transform_apply(tuple{}, args, - // Initialize each operation's workspace, stopping at the first error - [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { - if (status != Status::kSuccess) { - return status; - } - - using Op = cute::remove_cvref_t; - status = Op::initialize_workspace(problem_shape, op_args, op_workspace, stream, cuda_adapter); - if (op_workspace != nullptr) { - size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); - op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); - } - return status; - }, - // Return the final status - [&] (auto const&...ops) CUTLASS_LAMBDA_FUNC_INLINE { return status; } - ); - } - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) - : ops(transform_apply(tuple{}, params, shared_storage, - [] (auto&& op, auto const& op_params, auto&& op_storage) CUTLASS_LAMBDA_FUNC_INLINE { - using Op = cute::remove_cvref_t; - return Op(op_params, op_storage); - }, - [] (auto&&... ops) CUTLASS_LAMBDA_FUNC_INLINE { return cute::make_tuple(ops...); } - )) {} - - // Ops can store kernel persistent variables (e.g. descriptors, scalars, wave counters) - tuple ops; -}; - -template -struct Sm90VisitorImpl : Sm90VisitorImplBase { - - using Impl = Sm90VisitorImplBase; - using Params = typename Impl::Params; - using SharedStorage = typename Impl::SharedStorage; - - CUTLASS_HOST_DEVICE - Sm90VisitorImpl() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImpl(Params const& params, SharedStorage const& shared_storage) - : Impl(params, shared_storage) {} - - using Impl::ops; - - // - // Queries for kernel runtime - // - - // Is a specialized warp for producer TMA loads needed - // e.g. Aux tensor loads, broadcasts using TMA bulk copy - // This condition cannot change between work tiles because it is used - // to determine whether the load warp should exit early or not - // e.g. for batched beta this must always be true regardless of current batch idx - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return cute::apply(ops, - [] (auto const&... op) CUTLASS_LAMBDA_FUNC_INLINE { - return (false || ... || op.is_producer_load_needed()); - } - ); - } - - // Is a producer TMA load specifically for C needed - // If this is true then is_producer_load_needed must also be true - // This condition can change between work tiles because it is only used - // to determine whether the TMA and smem loads for C of a given tile should happen - // e.g. for batched beta this can be false depending on current batch idx - CUTLASS_DEVICE bool - is_C_load_needed() const { - return cute::apply(ops, - [] (auto const&... op) CUTLASS_LAMBDA_FUNC_INLINE { - return (false || ... || op.is_C_load_needed()); - } - ); - } - - // Producer load callbacks factory - // All operations must redefine this, but most can just dispatch to the base impl - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return transform_apply(ops, - [&] (auto& op) CUTLASS_LAMBDA_FUNC_INLINE { - return op.get_producer_load_callbacks(args); - }, - [] (auto&&... callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - auto callbacks_tuple = cute::make_tuple(callbacks...); - return ProducerLoadCallbacksImpl{callbacks_tuple}; - } - ); - } - - // Consumer store callbacks factory - // All operations must redefine this - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return transform_apply(ops, - [&] (auto& op) CUTLASS_LAMBDA_FUNC_INLINE { - return op.template get_consumer_store_callbacks(args); - }, - [] (auto&&... callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - auto callbacks_tuple = cute::make_tuple(callbacks...); - return ConsumerStoreCallbacksImpl{callbacks_tuple}; - } - ); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Convenience aliases -using EmptyProducerLoadCallbacks = ProducerLoadCallbacksImpl>; -using EmptyConsumerStoreCallbacks = ConsumerStoreCallbacksImpl>; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - -using namespace detail; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Tree visitor -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Sm90TreeVisitor : Sm90VisitorImpl { - - using Impl = Sm90VisitorImpl; - using Params = typename Impl::Params; - using SharedStorage = typename Impl::SharedStorage; - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor() {} - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor( - Params const& params, - SharedStorage const& shared_storage) - : Impl(params, shared_storage) {} - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks(CallbacksImpl&& impl) - : CallbacksImpl(cute::forward(impl)) {} - - using CallbacksImpl::callbacks_tuple; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - constexpr int Rm1 = sizeof...(ChildOps); - return cute::detail::tapply(callbacks_tuple, - [&] (auto& child_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { - return child_callbacks.visit(frg_acc, epi_v, epi_m, epi_n); // child ops must be nullary (e.g. loads, trees) - }, - [&] (auto&&... frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { - return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); - }, - make_seq{} // restrict the transform to R-1 child ops, apply is for node op - ); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_impl = Sm90VisitorImpl:: - template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(cute::move(callbacks_impl)); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// DAG visitors -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Most DAG fusions can be represented as a set of output trees with a common input tree -// The common input is first evaluated, then the result is passed as the acc fragment to the output trees -template -struct Sm90SplitTreeVisitor : Sm90VisitorImpl { - - using Sm90VisitorImpl::Sm90VisitorImpl; - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks(CallbacksImpl&& impl) - : CallbacksImpl(cute::forward(impl)) {} - - using CallbacksImpl::callbacks_tuple; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - Array frg_input = get<0>(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); - - constexpr int Rm2 = sizeof...(AuxOutTrees); - cute::for_each(make_seq{}, // restrict the sequence to aux out trees - [&] (auto I) CUTLASS_LAMBDA_FUNC_INLINE { - get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); - } - ); - - return get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_impl = Sm90VisitorImpl:: - template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(cute::move(callbacks_impl)); - } -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // deducing the output type for all the nodes is tricky so we just convert them all to a common type - // if multiple compute types are needed then split into multiple subgraphs grouped by type - class ElementCompute, - class EdgeTuple, // tuple of int_sequence, each sequence is the children indices (indexed by topological order) for each node - class... Ops // in topological order, last op is the output. EdgeTuple must match this order -> -struct Sm90TopologicalVisitor : Sm90VisitorImpl { - static_assert(is_static_v); - static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops)); - static_assert(sizeof...(Ops) > 1); - - using Sm90VisitorImpl::Sm90VisitorImpl; - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks(CallbacksImpl&& impl) - : CallbacksImpl(cute::forward(impl)) {} - - using CallbacksImpl::callbacks_tuple; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - constexpr int Rm1 = sizeof...(Ops) - 1; - auto frg_compute_tuple = cute::repeat(Array{}); - - return cute::detail::tapply(EdgeTuple{}, callbacks_tuple, frg_compute_tuple, - // Visit the first R-1 ops in topological order - [&] (auto&& edge_seq, auto& callbacks, auto& frg_compute) CUTLASS_LAMBDA_FUNC_INLINE { - frg_compute = cute::detail::apply(frg_compute_tuple, - // Compute the current op with children inputs - [&] (auto const&... frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { - auto frg_output = callbacks.visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); - using ElementOutput = typename decltype(frg_output)::Element; - using ConvertOutput = NumericArrayConverter; - ConvertOutput convert_output{}; - - return convert_output(frg_output); - }, - // Get inputs in the sequence given by the children indices of the current op - edge_seq - ); - return frg_compute; // unused - }, - // Visit the last op - [&] (auto const&...ops) CUTLASS_LAMBDA_FUNC_INLINE { - return cute::detail::apply(frg_compute_tuple, - // Compute the last op with children inputs - [&] (auto const&... frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { - return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); - }, - // Get inputs in the sequence given by the children indices of the last op - get(EdgeTuple{}) - ); - }, - // Transform to visit R-1 ops, apply to visit last op - make_seq{} - ); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_impl = Sm90VisitorImpl:: - template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(cute::move(callbacks_impl)); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Base specializations so we can have standard layout params and simple aggregate initializers -namespace detail { - -template -struct Sm90VisitorImplBase { - - // Retain tuple for SharedStorage because empty structs have 1B alignment - // tuples use multiple inheritance, avoids this problem - using SharedStorage = tuple< - typename Op0::SharedStorage - >; - - struct Arguments { - typename Op0::Arguments op_0; - }; - - struct Params { - typename Op0::Params op_0; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return Params{ - Op0::to_underlying_arguments(problem_shape, args.op_0, workspace) - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return Op0::can_implement(problem_shape, args.op_0); - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - size_t workspace_size = 0; - workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) - : ops({ - Op0(params.op_0, get<0>(shared_storage)) - }) {} - - tuple ops; -}; - -template -struct Sm90VisitorImplBase { - - using SharedStorage = tuple< - typename Op0::SharedStorage, - typename Op1::SharedStorage - >; - - struct Arguments { - typename Op0::Arguments op_0; - typename Op1::Arguments op_1; - }; - - struct Params { - typename Op0::Params op_0; - typename Op1::Params op_1; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); - uint8_t* op_0_workspace = reinterpret_cast(workspace); - uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; - return Params{ - Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), - Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace) - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return Op0::can_implement(problem_shape, args.op_0) && - Op1::can_implement(problem_shape, args.op_1); - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - size_t workspace_size = 0; - workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) - : ops({ - Op0(params.op_0, get<0>(shared_storage)), - Op1(params.op_1, get<1>(shared_storage)) - }) {} - - tuple ops; -}; - -template -struct Sm90VisitorImplBase { - - using SharedStorage = tuple< - typename Op0::SharedStorage, - typename Op1::SharedStorage, - typename Op2::SharedStorage - >; - - struct Arguments { - typename Op0::Arguments op_0; - typename Op1::Arguments op_1; - typename Op2::Arguments op_2; - }; - - struct Params { - typename Op0::Params op_0; - typename Op1::Params op_1; - typename Op2::Params op_2; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); - size_t op_1_workspace_size = Op1::get_workspace_size(problem_shape, args.op_1); - uint8_t* op_0_workspace = reinterpret_cast(workspace); - uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; - uint8_t* op_2_workspace = op_1_workspace + op_1_workspace_size; - return Params{ - Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), - Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace), - Op2::to_underlying_arguments(problem_shape, args.op_2, op_2_workspace) - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return Op0::can_implement(problem_shape, args.op_0) && - Op1::can_implement(problem_shape, args.op_1) && - Op2::can_implement(problem_shape, args.op_2); - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - size_t workspace_size = 0; - workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) - : ops({ - Op0(params.op_0, get<0>(shared_storage)), - Op1(params.op_1, get<1>(shared_storage)), - Op2(params.op_2, get<2>(shared_storage)) - }) {} - - tuple ops; -}; - -template -struct Sm90VisitorImplBase { - - using SharedStorage = tuple< - typename Op0::SharedStorage, - typename Op1::SharedStorage, - typename Op2::SharedStorage, - typename Op3::SharedStorage - >; - - struct Arguments { - typename Op0::Arguments op_0; - typename Op1::Arguments op_1; - typename Op2::Arguments op_2; - typename Op3::Arguments op_3; - }; - - struct Params { - typename Op0::Params op_0; - typename Op1::Params op_1; - typename Op2::Params op_2; - typename Op3::Params op_3; - }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); - size_t op_1_workspace_size = Op1::get_workspace_size(problem_shape, args.op_1); - size_t op_2_workspace_size = Op2::get_workspace_size(problem_shape, args.op_2); - uint8_t* op_0_workspace = reinterpret_cast(workspace); - uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; - uint8_t* op_2_workspace = op_1_workspace + op_1_workspace_size; - uint8_t* op_3_workspace = op_2_workspace + op_2_workspace_size; - return Params{ - Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), - Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace), - Op2::to_underlying_arguments(problem_shape, args.op_2, op_2_workspace), - Op3::to_underlying_arguments(problem_shape, args.op_3, op_3_workspace) - }; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - return Op0::can_implement(problem_shape, args.op_0) && - Op1::can_implement(problem_shape, args.op_1) && - Op2::can_implement(problem_shape, args.op_2) && - Op3::can_implement(problem_shape, args.op_3); - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - size_t workspace_size = 0; - workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += Op3::get_workspace_size(problem_shape, args.op_3); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - status = Op3::initialize_workspace(problem_shape, args.op_3, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += Op3::get_workspace_size(problem_shape, args.op_3); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) { - return status; - } - - return status; - } - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase() {} - - CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) - : ops({ - Op0(params.op_0, get<0>(shared_storage)), - Op1(params.op_1, get<1>(shared_storage)), - Op2(params.op_2, get<2>(shared_storage)), - Op3(params.op_3, get<3>(shared_storage)) - }) {} - - tuple ops; -}; - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp deleted file mode 100644 index bd378419567b1680c400ec38746211a577a3c409..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp +++ /dev/null @@ -1,763 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Visitor tree Top-K + Softmax fusion operation for sm90 TMA warp-specialized epilogue -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/workspace.h" - -#include "cute/tensor.hpp" -#include "sm90_visitor_tma_warpspecialized.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::epilogue::fusion { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Top-K + Softmax reduction across columns -// Performs a reduction of top-K values across N, and finally performs a softmax on them, -// and sets values not in the top-K to 0. -// -// Assumptions: -// 1. CTA_N >= N (single tile across N, the mode which is reduced) -// 2. EPI_N >= N (single epilogue tile across N, because we can reduce and revisit one -// epilogue tile at a time.) -// 3. Top-K value is either 2 or 4. -// - -namespace detail { - -// Implementations for add to sorted list and merging sorted lists, -// with fast paths for lists of size 2 and 4 (Top-2 and Top-4). -// Generic implementations may result in greater register use and branching, -// and should be avoided. -// Fast paths for Top-2 and Top-4 are written in inline PTX directly. - -CUTLASS_DEVICE -Array top_2_reduce_scalar(Array a, float scalar) { - Array out; - asm volatile( - "{\n" - " .reg .f32 mx;\n" - " .reg .pred p;\n" - " max.f32 mx, %3, %4;\n" - " setp.gtu.f32 p, %2, %4;\n" - " selp.f32 %1, mx, %2, p;\n" - " selp.f32 %0, %2, %4, p;\n" - "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(scalar)); - return out; -} - -CUTLASS_DEVICE -Array top_2_reduce(Array a, Array b) { - Array out; - asm volatile( - "{\n" - " .reg .v2 .f32 mx;\n" - " .reg .pred p;\n" - " max.f32 mx.x, %3, %4;\n" // max(a1, b0) - " max.f32 mx.y, %2, %5;\n" // max(a0, b1) - " setp.gtu.f32 p, %2, %4;\n" // a0 > b0 - " selp.f32 %1, mx.x, mx.y, p;\n" // a0 > b0 ? max(a1, b0) : max(a0, b1) - " selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0 - "}\n" : "=f"(out[0]), "=f"(out[1]) : - "f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1])); - return out; -} - -CUTLASS_DEVICE -Array top_4_reduce_scalar(Array a, float scalar) { - Array out; - asm volatile( - "{\n" - " .reg .f32 mx;\n" // max(a3, b) - " .reg .pred p0;\n" // a0 > b - " .reg .pred p1;\n" // a1 > b - " .reg .pred p2;\n" // a2 > b - " max.f32 mx, %7, %8;\n" // max(a3, b) - " setp.gtu.f32 p0, %4, %8;\n" // a0 > b - " setp.gtu.f32 p1, %5, %8;\n" // a1 > b - " setp.gtu.f32 p2, %6, %8;\n" // a2 > b - " selp.f32 %3, mx, %6, p2;\n" // a2 > b ? max(a3, b) : a2 - " selp.f32 %2, %6, %8, p2;\n" // a1 = a2 > b ? a2 : b - " selp.f32 %2, %2, %5, p1;\n" // a1 > b ? max(a2, b) : a1 == a1 > b ? a1 : old_a1 - " selp.f32 %1, %5, %8, p1;\n" // a0 = a1 > b ? a1 : b - " selp.f32 %1, %1, %4, p0;\n" // a0 > b ? max(a1, b) : a0 == a0 > b ? a0 : old_a0 - " selp.f32 %0, %4, %8, p0;\n" // a0 = a0 > b ? a0 : b - "}\n" : - "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : - "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar)); - return out; -} - -CUTLASS_DEVICE -Array top_4_reduce(Array a, Array b) { - Array out; - asm volatile( - "{\n" - " .reg .f32 mxa0b1;\n" // max(a0, b1) - " .reg .f32 mxa1b0;\n" // max(a1, b0) - - " .reg .f32 mxa2b0;\n" // max(a2, b0) - " .reg .f32 mxa1b1;\n" // max(a1, b1) - " .reg .f32 mxa0b2;\n" // max(a1, b1) - - " .reg .f32 mxa1b2;\n" // max(a1, b2) - " .reg .f32 mxa2b1;\n" // max(a2, b1) - " max.f32 mxa1b2, %5, %10;\n" - " max.f32 mxa2b1, %6, %9;\n" - - " .reg .f32 mxa3b0;\n" // max(a1, b2) - " .reg .f32 mxa0b3;\n" // max(a2, b1) - " max.f32 mxa3b0, %7, %8;\n" - " max.f32 mxa0b3, %4, %11;\n" - - " .reg .pred pa0b0;\n" // a0 > b0 - " .reg .pred pa1b0;\n" // a1 > b0 - " .reg .pred pa2b0;\n" // a2 > b0 - " .reg .pred pa0b1;\n" // a0 > b1 - " .reg .pred pa1b1;\n" // a1 > b1 - " .reg .pred pa0b2;\n" // a0 > b2 - " .reg .pred pb2a0;\n" // b1 > a0 - " .reg .pred pb1a0;\n" // b1 > a0 - - " setp.gtu.f32 pa0b0, %4, %8;\n" // a0 > b0 - " setp.gtu.f32 pa1b0, %5, %8;\n" // a1 > b0 - " setp.gtu.f32 pa2b0, %6, %8;\n" // a2 > b0 - " setp.gtu.f32 pa0b1, %4, %9;\n" // a0 > b1 - " setp.gtu.f32 pa1b1, %5, %9;\n" // a1 > b1 - " setp.gtu.f32 pa0b2, %4, %10;\n" // a0 > b2 - - " not.pred pb2a0, pa0b2;\n" - " not.pred pb1a0, pa0b1;\n" - - " selp.f32 mxa1b0, %5, %8, pa1b0;\n" // max(a1, b0) - " selp.f32 mxa0b1, %4, %9, pa0b1;\n" // max(a0, b1) - - " selp.f32 mxa1b1, %5, %9, pa1b1;\n" // max(a1, b1) - " selp.f32 mxa2b0, %6, %8, pa2b0;\n" // max(a2, b0) - " selp.f32 mxa0b2, %4, %10, pa0b2;\n" // max(a0, b2) - - // a0 - " selp.f32 %0, %4, %8, pa0b0;\n" // a0 = a0 > b0 ? a0 : b0 - - // a1 - " selp.f32 %1, mxa1b0, mxa0b1, pa0b0;\n" // a1 = a0 > b0 ? max(a1, b0) : max(a0, b1) - - // a2 - " mov.f32 %2, mxa1b1;\n" // a2 = max(a1, b1) ** most likely case - " selp.f32 %2, mxa2b0, %2, pa1b0;\n" // a0 > a1 > b0 - " selp.f32 %2, mxa0b2, %2, pb1a0;\n" // b0 > b1 > a0 - - // a3 - " mov.f32 %3, mxa1b2;\n" // a3 = max(a1, b2) ** one of the most likely cases - " selp.f32 %3, mxa2b1, %3, pa1b1;\n" // a3 = a1 > b1 ? max(a2, b1) ** second most likely case - " selp.f32 %3, mxa3b0, %3, pa2b0;\n" // a0 > a1 > a2 > b0 - " selp.f32 %3, mxa0b3, %3, pb2a0;\n" // b0 > b1 > b2 > a0 - "}\n" : - "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : - "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), - "f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3])); - return out; -} - -// Assumption: array elements are sorted in descending order -// (a[0] is the largest element in a[].) -template -CUTLASS_DEVICE -void add_element_to_desc_sorted_array(cutlass::Array& a, Element b) { - if constexpr (N == 2 && is_same_v) { - a = top_2_reduce_scalar(a, b); - } - else if constexpr (N == 4 && is_same_v) { - a = top_4_reduce_scalar(a, b); - } - else { - // slower generic path with branching, slower, and can cause register spill - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < N; ++k) { - if (a[k] < b) { - // Shift down - CUTLASS_PRAGMA_UNROLL - for (int l = N - 1; l > k; --l) { - a[l] = a[l-1]; - } - a[k] = b; - break; - } - } - } -} - -// Assumption: array elements are sorted in descending order -// (a[0] and b[0] are the largest elements in a[] and b[].) -template -CUTLASS_DEVICE -void merge_desc_sorted_arrays(cutlass::Array& a, const cutlass::Array& b) { - if constexpr (N == 2 && is_same_v) { - a = top_2_reduce(a, b); - } - else if constexpr (N == 4 && is_same_v) { - a = top_4_reduce(a, b); - } - else { - // slower generic path with branching, slower, and can cause register spill - int j = 0; - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < N; ++k) { - if (a[k] < b[j]) { - // Shift down - CUTLASS_PRAGMA_UNROLL - for (int l = N - 1; l > k; --l) { - a[l] = a[l-1]; - } - a[k] = b[j]; - ++j; - } - } - } -} - -// Assumption: array elements are sorted in descending order -// (a[0] is the largest element in a[].) -template -CUTLASS_DEVICE -Element topk_logsumexp(cutlass::Array a) { - // Do one less `exp`, because we know what its result will be. - // Assume x is a set of `x_i`s, and `x_m` is the maximum of that set. - // logsumexp(x) = log(sum(x_i)) = m + log(sum(x_i - m)) = m + log(1 + sum_{i != m}(x_i - x_m)) - // Compute m + log(1 + sum_{i != m}(x_i - x_m)) - Element sum = Element(1.0); - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < N; ++i) { - sum += fast_exp(a[i] - a[0]); - } - return a[0] + fast_log(sum); -} - -CUTLASS_DEVICE -float fast_masked_softmax(float value, float minimum, float logsumexp) { - float new_value; - asm volatile( - "{\n" - " .reg .pred p0;\n" - // value >= minimum - " setp.geu.f32 p0, %1, %2;\n" - - " .reg .f32 x_lse;\n" - " .reg .f32 %%f<11>;\n" - " .reg .b32 %%r<3>;\n" - - // x_lse = value - minimum - " sub.rn.f32 x_lse, %1, %3;\n" - - // exp(x_lse) - // The following is derived from a ptx dump of expf. - // exp requires a base conversion from exp2. - " fma.rn.f32 %%f1, x_lse, 0f3BBB989D, 0f3F000000;\n" - " cvt.sat.f32.f32 %%f2, %%f1;\n" - " fma.rm.f32 %%f3, %%f2, 0f437C0000, 0f4B400001;\n" - " add.f32 %%f4, %%f3, 0fCB40007F;\n" - " neg.f32 %%f5, %%f4;\n" - " fma.rn.f32 %%f6, x_lse, 0f3FB8AA3B, %%f5;\n" - " fma.rn.f32 %%f7, x_lse, 0f32A57060, %%f6;\n" - " mov.b32 %%r1, %%f3;\n" - " shl.b32 %%r2, %%r1, 23;\n" - " mov.b32 %%f8, %%r2;\n" - " ex2.approx.ftz.f32 %%f9, %%f7;\n" - " mul.f32 %%f10, %%f9, %%f8;\n" - - // Mask or softmax - " selp.f32 %0, %%f10, 0f00000000, p0;\n" - "}\n" : "=f"(new_value) : "f"(value), "f"(minimum), "f"(logsumexp)); - return new_value; -} - -template -CUTLASS_DEVICE -Element masked_softmax(Element value, Element minimum, Element logsumexp) { - if constexpr (is_same_v) { - // Inline PTX implementation - // Significantly reduces register requirements - return fast_masked_softmax(value, minimum, logsumexp); - } - else { - return value < minimum ? Element(0.0) : fast_exp(value - logsumexp); - } -} - -} // namespace detail - -template < - int TopK, - int FragmentSize, - class CtaTileShapeMNK, - class EpilogueTile, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - int Alignment = 128 / sizeof_bits_v, - bool UseButterflyReduce = true -> -struct Sm90TopKSoftmaxColReduction { -private: - static_assert(is_same_v, "Fused Top-K + Softmax reduction requires FP32 accumulation."); - static_assert(TopK == 2 || TopK == 4, - "Fused Top-K + Softmax reduction only allows K=2 and K=4, because those cases have been performance-optimized. Other values of K can be enabled by removing this assertion, but they may come with serious performance implications." - ); - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - - // Reduction tensors - // We have two tensors for this EVT node: a reduction tensor and a tensor holding - // final reduction values (tCrSoftmax). The reason for this is that Top-K and Softmax - // require different reductions, but those luckily overlap. Top-K obviously needs at least - // two values (K >= 2), and softmax needs one value: logsumexp. Logsumexp is simply the log - // of sum of exponents over the set, and is equivalent to m + sum(exp(x_i - m)), where m is the - // maximum of all x_i elements. Since safe softmax for any element x_i is computed as - // softmax(x_i) = exp(x_i - m) / sum_j(exp(x_j - max)) - // we can track logsumexp instead of tracking two variables (sum of exps and the max). - // In addition, subtracting logsumexp from any element and taking its exp is equivalent to - // computing its softmax. - // - // The overlap between softmax and top-K is that we don't need to reduce logsumexp along the - // way at all, because any element not in the top-K is going to be masked out and set to 0. - // Therefore, we only reduce the top-K elements, and when done, compute their logsumexp and - // keep it, and the smallest element in the top-K for masking out non-top-K elements. - // - // This means that our final reduction result will always be 2 elements, regardless of the value - // of K: minimum of top-K, and logsumexp. - // - // For each reduction tensor, we define a new struct for readability. - - struct ReductionResult { - ElementCompute min_; - ElementCompute logsumexp_; - - CUTLASS_DEVICE - ReductionResult() { } - - CUTLASS_DEVICE - ReductionResult(ElementCompute min, ElementCompute logsumexp): - logsumexp_(logsumexp), min_(min) { } - - // Warp shuffle broadcast - CUTLASS_DEVICE - void shuffle_up_sync(uint32_t delta, int lane_id) { - static_assert(sizeof(ReductionResult) == sizeof(uint64_t)); - uint64_t r = reinterpret_cast(*this); - r = __shfl_up_sync(0xFFFFFFFF, r, delta); - *this = (lane_id - static_cast(delta) >= 0) ? reinterpret_cast(r) : *this; - } - }; - - struct TopKResult { - Array top_k_; - - CUTLASS_DEVICE - TopKResult() { - top_k_.fill(-cutlass::platform::numeric_limits::infinity()); - } - - // This is where we do the "final" reduction, where we compute - // the logsumexp for softmax, keep the smallest value in top-K, - // and discard the rest. - CUTLASS_DEVICE - ReductionResult reduce_final() const { - return ReductionResult(top_k_[TopK - 1], topk_logsumexp(top_k_)); - } - - // Butterfly reduction - CUTLASS_DEVICE - void shuffle_xor_sync(int laneMask) { - if constexpr (TopK == 2) { - static_assert(sizeof(TopKResult) == sizeof(uint64_t)); - uint64_t top_k = reinterpret_cast(*this); - top_k = __shfl_xor_sync(0xFFFFFFFF, top_k, laneMask); - auto synced_v = reinterpret_cast(top_k); - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - else if constexpr (TopK == 4) { - static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); - uint64_t* top_k_ptr = reinterpret_cast(this); - uint64_t top_k_arr[2]; - top_k_arr[0] = top_k_ptr[0]; - top_k_arr[1] = top_k_ptr[1]; - top_k_arr[0] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask); - top_k_arr[1] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask); - auto synced_v = reinterpret_cast(top_k_arr); - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - else { - TopKResult synced_v; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < TopK; ++i) { - synced_v.top_k_[i] = __shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask); - } - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - } - - // Warp shuffle reduction - CUTLASS_DEVICE - void shuffle_down_sync(uint32_t delta) { - if constexpr (TopK == 2) { - static_assert(sizeof(TopKResult) == sizeof(uint64_t)); - uint64_t top_k = reinterpret_cast(*this); - top_k = __shfl_down_sync(0xFFFFFFFF, top_k, delta); - auto synced_v = reinterpret_cast(top_k); - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - else if constexpr (TopK == 4) { - static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); - uint64_t* top_k_ptr = reinterpret_cast(this); - uint64_t top_k_arr[2]; - top_k_arr[0] = top_k_ptr[0]; - top_k_arr[1] = top_k_ptr[1]; - top_k_arr[0] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta); - top_k_arr[1] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta); - auto synced_v = reinterpret_cast(top_k_arr); - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - else { - TopKResult synced_v; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < TopK; ++i) { - synced_v.top_k_[i] = __shfl_down_sync(0xFFFFFFFF, top_k_[i], delta); - } - detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); - } - } - }; - -public: - struct SharedStorage { }; - - struct Arguments { }; - - struct Params { }; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return {}; - } - - template - static bool - can_implement(ProblemShape const& problem_shape, Arguments const& args) { - auto [M, N, K, L] = problem_shape; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - // Cross CTA reduction is not possible because there is no guarantee that all CTAs run - // concurrently. - // Cross epilogue tile reduction is possible, but re-visiting and applying reduction - // to accumulators is only possible for the current epilogue tile. - auto [epi_M, epi_N] = EpilogueTile{}; - return N <= tile_N && N <= epi_N && N >= TopK; - } - - template - static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { - return 0; - } - - template - static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } - - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } - - CUTLASS_HOST_DEVICE - Sm90TopKSoftmaxColReduction() { } - - CUTLASS_HOST_DEVICE - Sm90TopKSoftmaxColReduction(Params const& params, SharedStorage const& shared_storage) - : params(params) { } - - Params params; - - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } - - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - CUTLASS_DEVICE - ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) - : args_tuple(cute::forward(args_tuple)), - params(params) {} - - ArgsTuple args_tuple; - Params const& params; - - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { - - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, - lane_layout_MN, lane_mn, - residue_cCol, residue_tCcCol] = args_tuple; - Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); - - using ConvertInput = NumericArrayConverter; - ConvertInput convert_input{}; - - Array frg_I = convert_input(frg_input); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - auto thread_crd = tCcCol_mn(epi_v * FragmentSize + i); - if (elem_less(thread_crd, residue_tCcCol)) { - TopKResult& tCrCol_vmn = tCrTopK(epi_v * FragmentSize + i); - detail::add_element_to_desc_sorted_array(tCrCol_vmn.top_k_, frg_I[i]); - } - } - - return frg_input; - } - - template - CUTLASS_DEVICE void - reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { - - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, - lane_layout_MN, lane_mn, - residue_cCol, residue_tCcCol] = args_tuple; - - // fully OOB CTA in partially OOB cluster - if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { - return; - } - Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); - - // `tCrTopK` and `tCrSoftmax` have 0-strides along modes that correspond to N, - // in order to reduce along modes in the `R2S` sublayout that correspond to N. - // This means we should modify and warp-reduce them according to their co-domain instead of - // their domain. Therefore we keep a filtered view of both and use them as necessary. - auto tCrTopK_f = filter(tCrTopK); - auto tCrSoftmax_f = filter(tCrSoftmax); - - // The pattern here is: reduce Top-K first, then compute logsumexp, keep it and the - // last element of Top-K, use the latter to mask the visited results, and the former - // to apply softmax. - // - // This gives us two options: reduce the Top-K with warp shuffles, have the reduced - // lanes compute logsumexp and pair it with the last Top-K element, and broadcast - // the result back using warp shuffles. - // - // Alternatively, we can do a butterfly reduction over Top-K, and have all lanes - // compute their own logsumexp and skip the broadcast. - if constexpr (UseButterflyReduce) { - // - // 1. Butterfly reduction - // - CUTLASS_PRAGMA_UNROLL - for (int j = 1; j < size<1>(lane_layout_MN); j *= 2) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrTopK_f); ++i) { - tCrTopK_f(i).shuffle_xor_sync(j); - } - } - - // - // 2. Strip down reduced value and compute sum of exps - // - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrSoftmax_f); ++i) { - tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); - } - } - else { - // - // 1. Warp shuffle reduction - // - CUTLASS_PRAGMA_UNROLL - for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrTopK_f); ++i) { - tCrTopK_f(i).shuffle_down_sync(lane_layout_MN(_0{},reduction_cols)); - } - } - - // - // 2. Strip down reduced value and compute sum of exps - // - bool is_reduced_lane = get<1>(lane_mn) == 0; - if (is_reduced_lane) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrSoftmax_f); ++i) { - tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); - } - } - - // - // 3. Broadcast reduced values to all participants - // - CUTLASS_PRAGMA_UNROLL - for (int broadcast_cols = 1; broadcast_cols <= size<1>(lane_layout_MN) / 2; broadcast_cols *= 2) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrSoftmax_f); ++i) { - tCrSoftmax_f(i).shuffle_up_sync(lane_layout_MN(_0{},broadcast_cols), get<1>(lane_mn)); - } - } - } - - // - // 4. Re-visit and apply top-K and softmax - // - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(visit_results); ++epi_v) { - auto& visit_frag = visit_results(epi_v); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentSize; ++i) { - visit_frag[i] = detail::masked_softmax( - visit_frag[i], - tCrSoftmax(epi_v * FragmentSize + i).min_, - tCrSoftmax(epi_v * FragmentSize + i).logsumexp_ - ); - } - } - - } - - CUTLASS_DEVICE void - end_loop(int epi_m, int epi_n) { - auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, - lane_layout_MN, lane_mn, - residue_cCol, residue_tCcCol] = args_tuple; - - // Reset reduced top-K values for next tile - // This must be done because we only assume a single epilogue tile across N, - // but not M. - fill(tCrTopK, TopKResult()); - } - - CUTLASS_DEVICE void - end() { } - - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - Layout ref_layout_MN = [&] () { - auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); - if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } - else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } - }(); // tile_mn -> tv_idx - - // Get the MN layout + coord of lanes to determine shuffle reduction iterations - using _W = Int; - Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx - Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx - Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx - Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn - int lane_idx = canonical_lane_idx(); - auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); - - // Get the MN layout + coord of warps to determine smem reduction iterations - Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx - Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx - Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx - - // Make sure there's only one warp across N so we can use warp shuffle intrinsics for reduction. - static_assert(decltype(size<1>(warp_layout_MN))::value <= 1); - - // Reduction layout - // We're assuming all elements in a row (over which we're performing the reduction) are - // visited in the same corresponding epilogue tile, and this is what allows us to apply the - // top-K + softmax operation within `reduce()`, by re-visiting the accumulated results. - // - // This presents a challenge, because the layout of the accumulated results is typically in - // in the register to shared memory shape, or: (R2S,R2S_M,R2S_N). - // This means that we still need to reduce this tensor along N. - // - // The solution is simple: we need to flatten the layout, identify modes that correspond to - // N and set their strides to 0, in order to map fragment indices corresponding to the same - // row back to the same element in the tensor. - // - // This requires some extra layout manipulation, which is as follows. - - // Create new accumulator layout with column broadcast - auto [M, N, K] = args.tile_shape_mnk; - auto thr_mma = args.tiled_mma.get_thread_slice(args.thread_idx); - auto gColReduce = make_tensor( - make_layout(make_shape(M, N), make_stride(_1{}, 0_c))); // (M,N) - auto tCrColReduce = make_tensor_like( // (FrgV, MMA_M, MMA_N) - thr_mma.partition_C(gColReduce).layout()); - - // Tile the new accumulator tensor according to R2S - ThrCopy thread_r2s = args.tiled_copy.get_slice(args.thread_idx); - Tensor tRS_rSoftmax = thread_r2s.retile_S(tCrColReduce); // ((R2S,R2S_V),MMA_M,MMA_N) - auto tCrC_layout = args.tCrC.layout(); // (R2S,R2S_M,R2S_N) - - // Compose the new accumulator R2S layout with the expected tCrC layout to get final - // reduction tensor layout. - auto tCrSoftmax_layout = take<0, 3>(tRS_rSoftmax.layout()).compose(tCrC_layout); // (R2S,R2S_V) o (R2S,R2S_M,R2S_N) - - Tensor tCrTopK = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) - Tensor tCrSoftmax = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) - fill(tCrTopK, TopKResult()); - - auto args_tuple = make_tuple( - cute::move(tCrTopK), cute::move(tCrSoftmax), args.tCcD, args.cD, - lane_layout_MN, lane_mn, - args.residue_cD, args.residue_tCcD); - return ConsumerStoreCallbacks(std::move(args_tuple), params); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::epilogue::fusion - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/activation.h b/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/activation.h deleted file mode 100644 index 8412b5037b3aacbca4d28b80b99839acb368d5df..0000000000000000000000000000000000000000 --- a/build/torch211-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/activation.h +++ /dev/null @@ -1,914 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief This extends the contents of cutlass/functional.h with frequently used activation functions. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/constants.h" -#include "cutlass/complex.h" -#include "cutlass/array.h" -#include "cutlass/half.h" -#include "cutlass/functional.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// If kIsHeavy is a member, use it. Otherwise, assume that it's false. -template -struct kIsHeavy_member_or_false { - static constexpr bool value = false; -}; -template -struct kIsHeavy_member_or_false::type> { - static constexpr bool value = Op::kIsHeavy; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Identity operator -template -struct Identity { - static const bool kIsHeavy = false; - - CUTLASS_HOST_DEVICE - T operator()(T value) const { - return value; - } -}; - -template -struct Identity > { - CUTLASS_HOST_DEVICE - Array operator()(Array value) const { - return value; - } -}; - -/// Scale operator -template -struct Scale { - struct Arguments { - using scale_type = T; - T scale = T(1); - }; - - CUTLASS_HOST_DEVICE - T operator()(T value, T scale) const { - multiplies mul; - return mul(scale, value); - } - - CUTLASS_HOST_DEVICE - T operator()(T value, Arguments args = Arguments()) const { - return this->operator()(value, args.scale); - } -}; - -template -struct Scale> { - using Arguments = typename Scale::Arguments; - - CUTLASS_HOST_DEVICE - Array operator()(Array values, T scale) const { - multiplies> mul; - return mul(scale, values); - } - - CUTLASS_HOST_DEVICE - Array operator()(Array values, Arguments args = Arguments()) const { - return this->operator()(values, args.scale); - } -}; - -/// Specialization to compose other activations with a defined unary operator -/// e.g. Scale> -template