diff --git a/.gitattributes b/.gitattributes index 47f34cbb3328da3f2d1261dd7e8f865b49fe8ae2..a4749a2c2e012db3683a05b8c4d27e9bf68bfe9b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -13,3 +13,4 @@ build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lf build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch210-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h new file mode 100644 index 0000000000000000000000000000000000000000..ec717fbcc16d1ea94ef6fb3e114edef50c3b506b --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/interleaved_epilogue.h @@ -0,0 +1,407 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator without splitk +template < + /// Shape of threadblock tile (concept: GemmShape) + typename Shape_, + /// Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + typename WarpMmaOperator_, + /// Number of partitions of the K dimension + int PartitionsK, + /// Tile iterator reading and writing output tensors + typename OutputTileIterator_, + /// Fragment iterator selecting accumulators + typename AccumulatorFragmentIterator_, + /// Output operator + typename OutputOp_, + /// Number of interleaved k + int InterleavedK> +class InterleavedEpilogue : + public EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_> +{ +public: + + using BaseStreamK = EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using OutputTileIterator = OutputTileIterator_; + using OutputOp = OutputOp_; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename 
AccumulatorFragmentIterator::AccumulatorTile; + + /// Fragment type used by the accumulator tile's fragment iterator + using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; + + /// Accumulator element + using ElementAccumulator = typename AccumulatorTile::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array; + + /// Array type used by output functor + using AccumulatorAccessType = + Array; + + /// Number of warps + using WarpCount = + gemm::GemmShape; + +public: + + static_assert(OutputTileIterator::kElementsPerAccess, + "This must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +public: + + /// Aspect for when epilogue source is not needed + struct SourceAspectNotNeeded + { + /// Constructor + CUTLASS_DEVICE + SourceAspectNotNeeded() + {} + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment) + { + OutputAccessType *output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const *compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL 
+ for (int i = 0; i < kOutputOpIterations; ++i) + { + // Call the output operator + output_frag_ptr[i] = output_op(compute_frag_ptr[i]); + } + } + }; + + + /// Aspect for when epilogue source is needed + struct SourceAspectNeeded + { + OutputTileIterator source_iterator; + + typename OutputTileIterator::Fragment source_fragment; + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + static void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment, + typename OutputTileIterator::Fragment const &source_fragment) + { + OutputAccessType *output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const *compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + OutputAccessType const *source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) + { + // Call the output operator + output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); + } + } + + /// Constructor + CUTLASS_DEVICE + SourceAspectNeeded(OutputTileIterator source_iterator) : + source_iterator(source_iterator) + { + source_fragment.clear(); + } + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment) + { + // Load addend source fragment from global memory + source_iterator.load(source_fragment); + ++source_iterator; + + apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment); + } + }; + + + /// Shared storage allocation needed by the epilogue + struct SharedStorage {}; 
+ + +public: + + /// Constructor + CUTLASS_DEVICE + InterleavedEpilogue( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx) ///< Id of thread within warp + : + BaseStreamK(thread_idx) + {} + + + /// Aggregates the accumulator sets shared by peer blocks in the global workspace, + /// performing epilogue computations, writing to output + CUTLASS_DEVICE + void reduce( + int peer_idx_begin, + int peer_idx_end, + int reduce_fragment_idx, + void *element_workspace, + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + { + // Redcuce peer accumulator fragments into one fragment + AccumulatorFragment accum_fragment; + BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace); + + // Source-fragment data (zero-initialized for scenarios where the + // output operator allows us to skip loading it from global input) + typename OutputTileIterator::Fragment source_fragment; + source_fragment.clear(); + + if (output_op.is_source_needed()) + { + source_iterator += reduce_fragment_idx; + source_iterator.load(source_fragment); + } + + // Compute the output result + typename OutputTileIterator::Fragment output_fragment; + + // Apply the output operator + SourceAspectNeeded::apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment); + + // Store the final result + destination_iterator += reduce_fragment_idx; + destination_iterator.store(output_fragment); + } + + + /// Perform the epilogue computations and stream the result to global memory. 
+ CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile + { + operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); + } + + + /// Perform the epilogue computations and stream the result to global memory. Implements + /// two alternative codepaths, depending on whether the output op requires addend data to be loaded. + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator ) ///< Tile iterator for addend source + { + if (output_op.is_source_needed()) + { + operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); + } + else + { + operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); + } + } + + + /// Perform the epilogue computations and stream the result to global memory. 
Implements a + /// single codepath, regardless of whether the output op requires addend data to be loaded + CUTLASS_DEVICE + void unified( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator ) ///< Tile iterator for addend source + { + if (!output_op.is_source_needed()) + { + source_iterator.clear_mask(); + __syncthreads(); // Dummy (CUDA 11.0) + } + + operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); + } + + + /// Streams the result to global memory + template + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + SourceAspect source) + { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Convert fragment + // + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + source.apply_output_operator(output_fragment, output_op, accum_fragment); + + // + // Store the final result + // + + destination_iterator.set_iteration_index(iter); + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + 
+//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..6f6d101d088fd3bab48067f428e5bf01fa35a67a --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/output_iterator_parameter.h @@ -0,0 +1,223 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template< + typename TensorLayout_, ///! The original output tensor layout + typename OutputIteratorLayout_, ///! Layout used by epilogue output iterator + typename TensorRef_, ///! Input tensor to epilogue output iterator + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ ///! Convolutional operator on 2D or 3D problem +> +struct ConvOutputIteratorParameter { + + using TensorLayout = TensorLayout_; + using OutputIteratorLayout = OutputIteratorLayout_; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = ConvOperator; + using ConvProblemSize = ConvProblemSize_; + + /// Wgrad stride idx for implicit gemm algorithm + // Conv2d row-major matrix (KxRSC) + // Conv3d row-major matrix (KxTRSC) + static int const kWgradStrideIdx = + platform::is_same::value ? 
2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradStrideIdx : 0); + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(kTensorStrideIdx); + } + + CUTLASS_HOST_DEVICE + static OutputTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } +}; + +template< + typename TensorRef_, ///! Input tensor to epilogue output iterator + typename ConvProblemSize_ ///! Convolutional operator on 2D or 3D problem +> +struct ConvOutputIteratorParameter { + + using TensorLayout = layout::TensorNHWC; + using OutputIteratorLayout = layout::TensorNHWC; + using MappedLayout = layout::RowMajor; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using MappedTensorCoord = typename MappedLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = conv::Operator::kFprop; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + + CUTLASS_HOST_DEVICE + static MappedTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } +}; + +template< + typename TensorRef_, ///! Input tensor to epilogue output iterator + typename ConvProblemSize_ ///! 
Convolutional operator on 2D or 3D problem +> +struct ConvOutputIteratorParameter { + + using TensorLayout = layout::TensorNHWC; + using OutputIteratorLayout = layout::TensorNHWC; + using MappedLayout = layout::RowMajor; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using MappedTensorCoord = typename MappedLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = conv::Operator::kDeconv; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + + CUTLASS_HOST_DEVICE + static MappedTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } +}; + +template< + typename TensorRef_, ///! Input tensor to epilogue output iterator + typename ConvProblemSize_ ///! Convolutional operator on 2D or 3D problem +> +struct ConvOutputIteratorParameter { + + using TensorLayout = layout::TensorNDHWC; + using OutputIteratorLayout = layout::TensorNDHWC; + using MappedLayout = layout::RowMajor; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using MappedTensorCoord = typename MappedLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = conv::Operator::kFprop; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + + CUTLASS_HOST_DEVICE + static MappedTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } +}; + +template< + typename TensorRef_, ///! Input tensor to epilogue output iterator + typename ConvProblemSize_ ///! 
Convolutional operator on 2D or 3D problem +> +struct ConvOutputIteratorParameter { + + using TensorLayout = layout::TensorNDHWC; + using OutputIteratorLayout = layout::TensorNDHWC; + using MappedLayout = layout::RowMajor; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using MappedTensorCoord = typename MappedLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = conv::Operator::kDeconv; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + + CUTLASS_HOST_DEVICE + static MappedTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } +}; + +template < + int InterleavedK, + typename TensorRef_, + conv::Operator ConvOperator, + typename ConvProblemSize_ +> +struct ConvOutputIteratorParameter< + layout::TensorNCxHWx, + layout::TensorNCxHWx, + TensorRef_, + ConvOperator, + ConvProblemSize_> +{ + + using TensorLayout = typename layout::TensorNCxHWx; + using OutputIteratorLayout = typename layout::TensorNCxHWx; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = ConvOperator; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + + CUTLASS_HOST_DEVICE + static OutputTensorCoord extent(ConvProblemSize problem_size) { + return problem_size.output_extent(); + } + +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h new file mode 100644 index 0000000000000000000000000000000000000000..2c011c1dc7268a9117b67bd75aefe8739f5441c7 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h @@ -0,0 +1,628 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/fast_math.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tuple defining point in output tile +template < + int Column, + int Row, + int Group, + int Cluster, + int Tile +> +struct OutputTileShape { + static int const kColumn = Column; + static int const kRow = Row; + static int const kGroup = Group; + static int const kCluster = Cluster; + static int const kTile = Tile; + + static int const kCount = kColumn * kRow * kGroup * kCluster * kTile; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct OutputTileThreadMapHelpers { + + /// Determines the iteration index of a vector access according to the thread map + CUTLASS_HOST_DEVICE + static void iteration_index( + int &column_idx, + int &row_idx, + int &group_idx, + int &cluster_idx, + 
int &tile_idx, + int iter_idx) { + + column_idx = iter_idx % Iterations::kColumn; + int residual = iter_idx / Iterations::kColumn; + + row_idx = residual % Iterations::kRow; + residual = residual / Iterations::kRow; + + group_idx = residual % Iterations::kGroup; + residual = residual / Iterations::kGroup; + + cluster_idx = residual % Iterations::kCluster; + tile_idx = residual / Iterations::kCluster; + } + + /// Computes the offset of a given vector access + CUTLASS_HOST_DEVICE + static MatrixCoord iteration_offset(int iter_idx) { + + int column_idx; + int row_idx; + int group_idx; + int cluster_idx; + int tile_idx; + + iteration_index(column_idx, row_idx, group_idx, cluster_idx, tile_idx, iter_idx); + + return + MatrixCoord( + row_idx * Delta::kRow + + group_idx * Delta::kGroup + + cluster_idx * Delta::kCluster + + tile_idx * Delta::kTile, + + column_idx * Delta::kColumn); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template < + typename ThreadMap_, + typename Shape_, + typename Iterations_, + typename Delta_, + typename Count_ +> +struct OutputTileThreadMap : public OutputTileThreadMapHelpers { + + /// Conventional thread map (concept: ThreadMap) + using ThreadMap = ThreadMap_; + + /// Number of threads participating in the operation + static int const kThreads = ThreadMap::kThreads; + + /// Number of scalar elements per access + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + /// Shape of the tile + using Shape = Shape_; + + /// Iterations performed by each thread + using Iterations = Iterations_; + + /// Delta between accesses + using Delta = Delta_; + + /// Number of iterator iterations + using Count = Count_; + + /// Initial offset function + CUTLASS_HOST_DEVICE + static MatrixCoord initial_offset(int thread_idx) { + + using Index = typename layout::PitchLinearCoord::Index; + + layout::PitchLinearCoord coord = ThreadMap::initial_offset(thread_idx); + + Index cluster 
= coord.strided() / (Shape::kGroup * Shape::kRow); + Index cluster_residual = coord.strided() % (Shape::kGroup * Shape::kRow); + + Index group = cluster_residual / (Shape::kRow); + Index row = cluster_residual % (Shape::kRow); + + return MatrixCoord{ + row + group * Shape::kRow * Count::kRow + + cluster * Shape::kGroup * Count::kGroup * Shape::kRow * Count::kRow, + coord.contiguous() + }; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// RowArrangement determines how one or more warps cover a region of consecutive rows. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize, + bool Is2dTile +> +struct RowArrangement; + +/// RowArrangement in which each warp's access is a 1D tiled arrangement. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize +> +struct RowArrangement { + static int const kWarpSize = 32; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + static int const kIterationsRow = 1; + static int const kDeltaRow = 1; + static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize; + static int const kDeltaColumn = kWarpSize * kElementsPerAccess; + + static int const kAccessWidth = kWarpSize; + static int const kAccessRows = 1; + static int const kWarpPartitionsRow = 1; + static int const kWarpPartitionsColumn = WarpsRemaining; +}; + +/// RowArrangement in which each warp's access is a 2D tiled arrangement. 
+template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize +> +struct RowArrangement { + + static int const kMemoryAccessSize = 256; // Preferred access size + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + struct Detail { + static int const kShapeRow = Shape::kRow / WarpsRemaining; + static int const kShapeWidth = Shape::kColumn / kElementsPerAccess; + + static int const kTargetMemoryAccessWidth = + kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8); + + static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth; + }; + + static int const kAccessWidth = + (Detail::kTargetAccessRows > Detail::kShapeRow ? + kWarpSize / Detail::kShapeRow + : const_min( + Detail::kShapeWidth, + const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8)) + )); + + static int const kAccessRows = + (Detail::kTargetAccessRows > Detail::kShapeRow ? 
+ Detail::kShapeRow + : const_min(Shape::kRow, kWarpSize / kAccessWidth)); + + static int const kIterationsRow = Detail::kShapeRow / kAccessRows; + static int const kDeltaRow = kAccessRows; + + static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth; + static int const kDeltaColumn = kAccessWidth * kElementsPerAccess; + + static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access"); + static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" ); + static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" ); + + static int const kWarpPartitionsRow = 1; + static int const kWarpPartitionsColumn = 1; +}; + +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Template metaprogram for partitioning a 4D space across warps to achieve several performance +/// objectives: +/// +/// - coalesced memory accesses in units of 128 Byte lines +/// - minimal address arithmetic +/// - minimal predicate calculations +/// +template < + typename Shape_, + typename Count_, + int Threads, + int ElementsPerAccess, + int ElementSize +> +struct OutputTileOptimalThreadMap { + + using Shape = Shape_; + using Count = Count_; + + static int const kWarpSize = 32; + static int const kThreads = Threads; + static int const kWarpCount = kThreads / kWarpSize; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + // + // Metaprogram computation + // + + struct Detail { + + // Clusters + static int const kIterationsCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kCluster / kWarpCount + : 1); + + static int const kDeltaCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster + : 1); + + static int const kCompactedDeltaCluster = + ((Shape::kCluster > kWarpCount) ? 
+ Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster + : 1); + + static int const kWarpPartitionsCluster = + ((Shape::kCluster > kWarpCount) ? + kWarpCount + : kWarpCount / Shape::kCluster); + + static int const kWarpsRemainingForGroups = + ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster); + + // Groups + static int const kIterationsGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kGroup / kWarpsRemainingForGroups + : 1); + + static int const kDeltaGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup + : 1); + + static int const kCompactedDeltaGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kRow * Shape::kGroup / kIterationsGroup + : 1); + + static int const kWarpPartitionsGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + 1 + : kWarpsRemainingForGroups / Shape::kGroup); + + static int const kWarpsRemainingForRows = + ((Shape::kGroup > kWarpsRemainingForGroups) ? 
+ 1 + : kWarpsRemainingForGroups / Shape::kGroup); + + // Rows + using RowArrangement = detail::RowArrangement< + Shape, + kWarpsRemainingForRows, + kElementsPerAccess, + kElementSize, + (Shape::kRow > kWarpsRemainingForRows) + >; + + // Warp partitions + using WarpPartitions = OutputTileShape< + RowArrangement::kWarpPartitionsColumn, + RowArrangement::kWarpPartitionsRow, + kWarpPartitionsGroup, + kWarpPartitionsCluster, + 1>; + + static int const kAccessWidth = RowArrangement::kAccessWidth; + static int const kAccessRows = RowArrangement::kAccessRows; + }; + + // + // Output + // + + using Iterations = OutputTileShape< + Detail::RowArrangement::kIterationsColumn, + Detail::RowArrangement::kIterationsRow, + Detail::kIterationsGroup, + Detail::kIterationsCluster, + 1>; + + using Delta = OutputTileShape< + Detail::RowArrangement::kDeltaColumn, + Detail::RowArrangement::kDeltaRow, + Detail::kDeltaGroup, + Detail::kDeltaCluster, + 1>; + + /// Initial offset function + CUTLASS_HOST_DEVICE + static MatrixCoord initial_offset(int thread_idx) { + +// int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); + int warp_idx = thread_idx / kWarpSize; + int lane_idx = thread_idx % kWarpSize; + + // Compute warp location + int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; + int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; + + int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; + int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; + + int row_idx = residual_group / Detail::WarpPartitions::kRow; + int col_idx = residual_group % Detail::WarpPartitions::kRow; + + // Compute per-lane offset + int lane_row_offset = lane_idx / Detail::kAccessWidth; + int lane_col_offset = lane_idx % Detail::kAccessWidth; + + // Compute coordinate in output space + int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup; + int group_offset = group_idx * Shape::kRow * Count::kRow; + int 
row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; + int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; + + return MatrixCoord( + cluster_offset + group_offset + row_offset + lane_row_offset, + column_offset + lane_col_offset * kElementsPerAccess + ); + } + + /// Computes the offset of a given vector access + CUTLASS_HOST_DEVICE + static MatrixCoord iteration_offset(int iter_idx) { + return OutputTileThreadMapHelpers::iteration_offset(iter_idx); + } + + /// Compacted thread map in which the 4D region is contiguous + struct CompactedThreadMap { + + + using Shape = Shape_; + + using TileShape = MatrixShape< + Shape::kTile * Shape::kCluster * Shape::kGroup * Shape::kRow, + Shape::kColumn + >; + + using Iterations = OutputTileShape< + Detail::RowArrangement::kIterationsColumn, + Detail::RowArrangement::kIterationsRow, + Detail::kIterationsGroup, + Detail::kIterationsCluster, + 1>; + + using Delta = OutputTileShape< + Detail::RowArrangement::kDeltaColumn, + Detail::RowArrangement::kDeltaRow, + Detail::kCompactedDeltaGroup, + Detail::kCompactedDeltaCluster, + 1>; + + /// Number of elements within each vector access + static int const kElementsPerAccess = ElementsPerAccess; + + /// Number of threads + static int const kThreads = Threads; + + /// Function to compute each thread's initial offset + CUTLASS_HOST_DEVICE + static MatrixCoord initial_offset(int thread_idx) { + +// int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); + int warp_idx = thread_idx / kWarpSize; + int lane_idx = thread_idx % kWarpSize; + + // Compute warp location + int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; + int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; + + int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; + int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; + + int row_idx = residual_group / Detail::WarpPartitions::kRow; + int col_idx = residual_group 
% Detail::WarpPartitions::kRow; + + // Compute per-lane offset + int lane_row_offset = lane_idx / Detail::kAccessWidth; + int lane_col_offset = lane_idx % Detail::kAccessWidth; + + // Compute coordinate in output space + int cluster_offset = cluster_idx * Shape::kRow * Shape::kGroup; + int group_offset = group_idx * Shape::kRow; + int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; + int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; + + MatrixCoord coord( + cluster_offset + group_offset + row_offset + lane_row_offset, + column_offset + lane_col_offset * kElementsPerAccess + ); + + return coord; + } + }; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Template metaprogram for partitioning a 3D interleaved layout across warps +/// to achieve several performance objectives: +/// +/// - coalesced memory accesses in units of 64 Byte lines +/// - minimal address arithmetic +/// - minimal predicate calculations +/// +template +struct InterleavedOutputTileThreadMap { + using WarpCount = WarpCount_; + + static int const kWarpSize = 32; + static int const kThreads = Threads; + static int const kWarpCount = kThreads / kWarpSize; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + // + // Metaprogram computation + // + + struct Detail {}; + + // + // Output + // + + using Iterations = Iterations_; + + using Delta = layout::PitchLinearShape; + + /// Initial offset function + CUTLASS_HOST_DEVICE + static layout::PitchLinearCoord initial_offset(int thread_idx) { + int warp_idx = thread_idx / kWarpSize; + int lane_idx = thread_idx % kWarpSize; + + // Compute warp location + layout::PitchLinearCoord warp_footprint{ + Delta::kContiguous * Iterations::kContiguous, + Delta::kStrided * Iterations::kStrided}; + + layout::PitchLinearCoord warp_offset{warp_idx % WarpCount::kContiguous, + warp_idx / WarpCount::kContiguous}; + 
+ // Compute per-lane offset + layout::PitchLinearCoord thread_offset_in_warp{ + lane_idx * kElementsPerAccess, 0}; + + layout::PitchLinearCoord thread_offset_in_threadblock_tile = + warp_footprint * warp_offset + thread_offset_in_warp; + + return thread_offset_in_threadblock_tile; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Template metaprogram for partitioning a 4D interleaved layout across warps +/// to achieve several performance objectives: +/// +/// - coalesced memory accesses in units of 64 Byte lines +/// - minimal address arithmetic +/// - minimal predicate calculations +/// +template +struct InterleavedConvOutputTileThreadMap { + using WarpCount = WarpCount_; + + static int const kWarpSize = 32; + static int const kThreads = Threads; + static int const kWarpCount = kThreads / kWarpSize; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + // + // Metaprogram computation + // + + struct Detail {}; + + // + // Output + // + + using Iterations = Iterations_; + + using Delta = MatrixShape; + + /// Initial offset function + CUTLASS_HOST_DEVICE + static MatrixCoord initial_offset(int thread_idx) { + int warp_idx = thread_idx / kWarpSize; + int lane_idx = thread_idx % kWarpSize; + + // Compute warp location + MatrixCoord warp_footprint{ + Delta::kRow * Iterations::kRow, + Delta::kColumn * Iterations::kColumn, + }; + + MatrixCoord warp_offset{warp_idx % WarpCount::kRow, + warp_idx / WarpCount::kRow}; + + // Compute per-lane offset + MatrixCoord thread_offset_in_warp{lane_idx / 4, + (lane_idx % 4) * kElementsPerAccess}; + + MatrixCoord thread_offset_in_threadblock_tile = + warp_footprint * warp_offset + thread_offset_in_warp; + + return thread_offset_in_threadblock_tile; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git 
a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..7c4692ffa29519b137dbb6dfb5918a85794178d2 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -0,0 +1,1387 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/permute.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not + bool UseCUDAStore = false +> +class PredicatedTileIterator { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static bool constexpr PermuteD = !layout::is_trivial_permute; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout): + 
PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) + { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, + // Not needed. Added to be compatible with strided conv epilogue. + cutlass::Tensor4DCoord const &tensor_extent): + Params(layout) + { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, + // Not needed. Added to be compatible with strided conv epilogue. + cutlass::Tensor5DCoord const &tensor_extent): + Params(layout) + { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer. This pointer is usually for both load() and store(), unless PermuteD is performed. When having PermuteD, byte_pointer_ is only for load(). + uint8_t *byte_pointer_; + + /// Byte-level pointer for store(). Due to PermuteD Op, store_byte_pointer_ may be with different address computation compared to byte_pointer_. 
+ uint8_t *store_byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const *indices_; + + /// PermuteDLayout + PermuteDLayout permute_layout_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIterator( + PredicatedTileIteratorParams const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const *indices = nullptr + ): + params_(params), indices_(indices), + permute_layout_(PitchLinearCoord(extent.column(), extent.row()), params_.stride * kElementsPerAccess / sizeof(AccessType)) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize 
byte_pointer_ + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // store_byte_pointer_ is set to be the same with byte_pointer_ unless PermuteD is used. + store_byte_pointer_ = PermuteD ? reinterpret_cast(pointer) : byte_pointer_; + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + store_byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast(byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + 
CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { + uint8_t *byte_pointer = store_byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast(byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + 
+ CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + if (PermuteD) { + + int col_offset = column * ThreadMap::Delta::kColumn; + + int col = col_offset + thread_start_column_; + int row = row_offset + thread_start_row_; + + // Locate memory_pointer + memory_pointer = reinterpret_cast(byte_pointer + byte_offset + + permute_layout_(PitchLinearCoord(col, row)) * sizeof(AccessType) / kElementsPerAccess); + } + + if (UseCUDAStore) { + if (guard) { + memory_pointer[0] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)&memory_pointer[0], + guard); + } + + if (!PermuteD) { + memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD && !PermuteD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + if (!ScatterD && !PermuteD) { + byte_pointer += params_.increment_group; + } + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + if (!ScatterD && !PermuteD) { + byte_pointer += params_.increment_cluster; + } + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) const { + + store_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset(Fragment &frag, int64_t byte_offset, int convolution_P, int convolution_Q, int add_P, int add_Q, int problem_N) const { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + 
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = (input_row-output_row)*problem_N*sizeof(float); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset(Fragment &frag, int64_t byte_offset, int convolution_P, int convolution_Q, int add_P, int add_Q, int problem_N) const { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int 
cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) row_add_P = 0; + if (output_Q > convolution_Q - 2) row_add_Q = 0; + + int input_row = output_N * (convolution_P/2) * (convolution_Q/2) + + ((output_P + row_add_P)/2) * (convolution_Q/2) + (output_Q + row_add_Q)/2; + + int64_t byte_offset = (input_row-output_row)*problem_N*sizeof(float); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } 
+ } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + + ++state_[0]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_row; + } + + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += params_.advance_row; + } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_group; + } + + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += params_.advance_group; + } + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_cluster; + } + + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += params_.advance_cluster; + } + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + + if (!ScatterD) { + byte_pointer_ += params_.advance_tile; + } + + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += params_.advance_tile; + } + + thread_start_row_ += ThreadMap::Shape::kGroup * 
ThreadMap::Shape::kRow + * ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; + } + } + } + + return *this; + } + + /// Advances a number of positions to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator+=(int increment) + { + // Row + state_[0] += increment; + int increment_row = state_[0] / ThreadMap::Count::kRow; + state_[0] = state_[0] % ThreadMap::Count::kRow; + + byte_pointer_ += (params_.advance_row * increment); + store_byte_pointer_ += (params_.advance_row * increment); + thread_start_row_ += (ThreadMap::Shape::kRow * increment); + + // Group + state_[1] += increment_row; + int increment_group = state_[1] / ThreadMap::Count::kGroup; + state_[1] = state_[1] % ThreadMap::Count::kGroup; + + byte_pointer_ += (params_.advance_group * increment_row); + store_byte_pointer_ += (params_.advance_group * increment_row); + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * + ThreadMap::Count::kRow * + increment_row; + + + // Cluster + state_[2] += increment_group; + int increment_cluster = state_[2] / ThreadMap::Count::kCluster; + state_[2] = state_[2] % ThreadMap::Count::kCluster; + + byte_pointer_ += (params_.advance_cluster * increment_group); + store_byte_pointer_ += (params_.advance_cluster * increment_group); + thread_start_row_ += + ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * + ThreadMap::Shape::kRow * + increment_group; + + // Tile + byte_pointer_ += (params_.advance_tile * increment_cluster); + store_byte_pointer_ += (params_.advance_tile * increment_cluster); + thread_start_row_ += + ThreadMap::Shape::kGroup * + ThreadMap::Shape::kRow * + ThreadMap::Shape::kCluster * + ThreadMap::Shape::kTile * + increment_cluster; + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + 
} + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | InterleavedPredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + int InterleavedN ///< Number of Interleaved N +> +class InterleavedPredicatedTileIterator { +public: + using ThreadMap = ThreadMap_; + + using Element = Element_; + + using Layout = layout::ColumnMajorInterleaved; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = layout::PitchLinearCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Iterations::kCount; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + /// Uses a non-template class + struct Params : InterleavedPredicatedTileIteratorParams { + using Base = InterleavedPredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout): + Base( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_InterleavedPredicatedTileIteratorDesc() + ) { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + /// Mask object + struct Mask { + static int const kCount = (ThreadMap::Iterations::kContiguous < 8) + ? 
8 + : ThreadMap::Iterations::kContiguous; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + Params params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in columns + Index extent_col_; + + /// A thread's starting column position (assuming steady-state predicates have + /// been computed) + Index thread_start_col_; + + /// Internal iteration counter + int iteration_contiguous_; + + int iteration_strided_; + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + InterleavedPredicatedTileIterator( + Params const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ): + params_(params) { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + + TensorCoord(threadblock_offset.contiguous() * InterleavedN, + threadblock_offset.strided() / InterleavedN); + + extent_col_ = extent.strided() / InterleavedN; + thread_start_col_ = thread_offset.strided(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + mask_.predicates[c] = + ((thread_offset.contiguous() + ThreadMap::Delta::kContiguous 
* c) < + (extent.contiguous() * InterleavedN)); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.strided()) * LongIndex(params_.stride) + + LongIndex(thread_offset.contiguous()) * sizeof(AccessType) / kElementsPerAccess; + + // Initialize internal state counter + iteration_contiguous_ = iteration_strided_ = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + AccessType *memory_pointer = reinterpret_cast(byte_pointer); + + int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided; + + bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); + + bool guard = col_guard && mask_.predicates[iteration_contiguous_]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + *frag_ptr, + (void *)memory_pointer, + guard); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + AccessType *memory_pointer = reinterpret_cast(byte_pointer); + + int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided; + + bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); + + bool guard = col_guard && mask_.predicates[iteration_contiguous_]; + + cutlass::arch::global_store( + *frag_ptr, (void *)memory_pointer, guard); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int iteration) { + iteration_contiguous_ = iteration % ThreadMap::Iterations::kContiguous; + iteration_strided_ = iteration / ThreadMap::Iterations::kContiguous; + } + + /// Advances to the next position to load or store + 
CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIterator &operator++() { + + ++iteration_contiguous_; + byte_pointer_ += params_.advance_row; + + if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) { + + iteration_contiguous_ = 0; + ++iteration_strided_; + byte_pointer_ += params_.advance_column; + + if (iteration_strided_ == ThreadMap::Iterations::kStrided) { + iteration_strided_ = 0; + } + } + + return *this; + } + + /// Advances a number of positions to load or store + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIterator &operator+=(int increment) + { + // Contiguous + iteration_contiguous_ += increment; + int increment_strided = iteration_contiguous_ / ThreadMap::Iterations::kContiguous; + iteration_contiguous_ = iteration_contiguous_ % ThreadMap::Iterations::kContiguous; + byte_pointer_ += (params_.advance_row * increment); + + // Strided + iteration_strided_ += increment_strided; + byte_pointer_ += (params_.advance_column * increment_strided); + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from global memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator | InterleavedMaskedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + int InterleavedN ///< Number of Interleaved N +> +class InterleavedConvPredicatedTileIterator { +public: + using ThreadMap = ThreadMap_; + + using Element = Element_; + + using Layout = layout::TensorNCxHWx; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = Tensor4DCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Iterations::kCount; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + struct Params { + + // + // Data members + // + + LongIndex stride_col; ///< stride in bytes between columns + LongIndex stride_row; ///< stride in bytes between rows + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Status initialize(typename Layout::Stride stride_) { + stride_col = stride_[1]; + stride_row = stride_[2]; + + return Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Params() { + initialize(cutlass::make_Coord(0, 0, 0)); + } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout) { + + initialize(layout.stride()); + } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, + // Not needed. Added to be compatible with strided conv epilogue. + cutlass::Tensor4DCoord const &tensor_extent): + Params(layout) + { } + + }; + + /// Mask object + struct Mask { + static int const kCount = + (ThreadMap::Iterations::kRow < 8) ? 
8 : ThreadMap::Iterations::kRow; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + Params params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in columns + Index extent_col_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in pq + Index extent_pq_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column position (assuming steady-state predicates have + /// been computed) + Index thread_start_col_; + + /// Internal iteration counter + LongIndex iteration_row_; + LongIndex iteration_col_; + + uint32_t pq_mul_; + + uint32_t pq_shr_; + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + InterleavedConvPredicatedTileIterator( + Params const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + MatrixCoord threadblock_offset + ): + params_(params) { + MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_col_ = extent.c(); + extent_pq_ = extent.h() * extent.w(); + extent_row_ = extent.n() * extent_pq_; + + find_divisor(pq_mul_, pq_shr_, extent_pq_); + + thread_start_row_ = thread_offset.row(); + thread_start_col_ 
= thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int r = 0; r < ThreadMap::Iterations::kRow; ++r) { + mask_.predicates[r] = + ((thread_offset.row() + ThreadMap::Delta::kRow * r) < extent_row_); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + ((thread_start_col_ / InterleavedN) * params_.stride_col + + (thread_start_col_ % InterleavedN)) * + sizeof_bits::value / 8; + + // Initialize internal state counter + iteration_row_ = iteration_col_ = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + + int col_offset = iteration_col_ * ThreadMap::Delta::kColumn; + bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); + bool guard = col_guard && mask_.predicates[iteration_row_]; + + int n, pq_rem; + + fast_divmod(n, pq_rem, + thread_start_row_ + iteration_row_ * ThreadMap::Delta::kRow, + extent_pq_, pq_mul_, pq_shr_); + + uint8_t *byte_pointer = + byte_pointer_ + (n * params_.stride_row + pq_rem * InterleavedN) * + sizeof_bits::value / 8; + AccessType *frag_ptr = reinterpret_cast(&frag); + AccessType const *memory_pointer = + reinterpret_cast(byte_pointer); + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + *frag_ptr, + (void *)memory_pointer, + guard); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + + int col_offset = iteration_col_ * ThreadMap::Delta::kColumn; + bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); + bool guard = col_guard && mask_.predicates[iteration_row_]; + + int n, pq_rem; + + fast_divmod(n, pq_rem, + thread_start_row_ + iteration_row_ * ThreadMap::Delta::kRow, + extent_pq_, pq_mul_, pq_shr_); + + uint8_t *byte_pointer = + byte_pointer_ + (n * params_.stride_row + 
pq_rem * InterleavedN) * + sizeof_bits::value / 8; + AccessType const *frag_ptr = reinterpret_cast(&frag); + AccessType *memory_pointer = reinterpret_cast(byte_pointer); + + cutlass::arch::global_store( + *frag_ptr, (void *)memory_pointer, guard); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int iteration) { + iteration_row_ = iteration % ThreadMap::Iterations::kRow; + iteration_col_ = iteration / ThreadMap::Iterations::kRow; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + InterleavedConvPredicatedTileIterator &operator++() { + + ++iteration_row_; + + if (iteration_row_ == ThreadMap::Iterations::kRow) { + + iteration_row_ = 0; + ++iteration_col_; + byte_pointer_ += params_.stride_col; + + if (iteration_col_ == ThreadMap::Iterations::kColumn) { + iteration_col_ = 0; + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h new file mode 100644 index 0000000000000000000000000000000000000000..7068c39409f4a25d97e6adfb5360d2c0d226e1e6 --- /dev/null +++ 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h @@ -0,0 +1,615 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +/// It provides a fast path for the case Rank = 2 which does not need div/rem to +/// calculate modes. 
+ +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + int Rank +> +class PredicatedTileIteratorAffineRankN { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::AffineRankN; + using TensorRef = TensorRef; + using TensorView = TensorView; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = typename Layout::TensorCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + static_assert( !(Layout::kRank % 2), + "Layout rank must be even. 
This assumes the first half of the modes correspond to the 'row' " + "and the second half of the modes correspond to the 'column'"); + + static bool const kBigEndian = false; + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Parameters structure + struct Params { + + // + // Data members + // + + Layout layout; + + /// Stride in units of bytes along M modes + Coord stride_m; + + /// Stride in units of bytes along N modes + Coord stride_n; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; + + int64_t rank2_inc_col; + int64_t rank2_inc_row; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(TensorCoord const &extent, Layout const &layout_): layout(layout_) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i]); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); + } + + if (kBigEndian) { + // "Big Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i + 1]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]); + } + } + else { + // "Little Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]); + } + } + + #if 0 + // + // Debug print statements to verify extents and strides are passed correctly. 
+ // + printf("PredicatedTileIteratorAffine::Params() entered\n"); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + printf(" extent[%d]: %d\n", i, extent[i]); + } + for (int i = 0; i < Layout::kRank; ++i) { + printf(" stride[%d]: %ld\n", i, layout_.stride()[i]); + } + printf("PredicatedTileIteratorAffine::Params() returning\n"); + #endif + } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout_): layout(layout_) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i]); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); + } + + rank2_inc_col = ThreadMap::Delta::kColumn * stride_n[0]; + rank2_inc_row = ThreadMap::Delta::kRow * stride_m[0]; + } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ Params params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in columns + Index extent_col_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column position (assuming steady-state predicates have been computed) + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Offsets in columns, cached for performance + int64_t offset_modes_n_[ThreadMap::Iterations::kColumn]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorAffineRankN( + Params const & params, + Element *pointer, + MatrixCoord extent, + int thread_idx, + MatrixCoord threadblock_offset = MatrixCoord(), + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ): + params_(params) + { + + MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_col_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + if (Layout::kRank > 2) { + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + // + // Compute coordinate and decompose into N modes + // + + int coord_n = thread_start_column_ + c * ThreadMap::Delta::kColumn; + + mask_.predicates[c] = coord_n < extent.column(); + + Coord modes_n; + + int64_t offset_modes_n = 0; + + if (kBigEndian) { + modes_n = 
CoordinateDecomposition(coord_n, params_.divmod_n); + + offset_modes_n = dot(modes_n, params_.stride_n); + } + else { + modes_n = CoordinateDecompositionLittleEndian(coord_n, params_.divmod_n); + + offset_modes_n = dot(modes_n, params_.stride_n); + } + + offset_modes_n_[c] = offset_modes_n; + + } + + if (!pointer) { + mask_.clear(); + } + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer); + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { + uint8_t const *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; + int64_t offset_modes_m = row_begin * params_.stride_m[0]; + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + // + // Compute coordinate and decompose into M modes + // + + int coord_m = row * ThreadMap::Delta::kRow + row_begin; + + Coord modes_m; + + if (Layout::kRank > 2) { + if (kBigEndian) { + modes_m = CoordinateDecomposition(coord_m, params_.divmod_m); + } else { + modes_m = CoordinateDecompositionLittleEndian(coord_m, params_.divmod_m); + } + + offset_modes_m = dot(modes_m, params_.stride_m); + } + + // + // Compute the offset due to modes M + // + + bool row_guard = (coord_m < extent_row_); + int64_t 
offset_modes_n = thread_start_column_ * params_.stride_n[0]; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + // + // Compute coordinate and decompose into N modes + // + + if (Layout::kRank > 2) { + offset_modes_n = offset_modes_n_[column]; + } + + // + // Compute the pointer and access + // + bool guard; + + if (Layout::kRank > 2) { + guard = row_guard && mask_.predicates[column]; + } else { + guard = (coord_m < extent_row_) && + ((thread_start_column_ + ThreadMap::Delta::kColumn * column) < extent_col_); + } + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset), + guard + ); + + if (Layout::kRank == 2) { + offset_modes_n += params_.rank2_inc_col; + } + } + + if (Layout::kRank == 2) { + offset_modes_m += params_.rank2_inc_row; + } + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; + int64_t offset_modes_m = row_begin * params_.stride_m[0]; + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + // + // Compute coordinate and decompose into M modes + // + + int coord_m = row * 
ThreadMap::Delta::kRow + row_begin; + + Coord modes_m; + + if (Layout::kRank > 2) { + if (kBigEndian) { + modes_m = CoordinateDecomposition(coord_m, params_.divmod_m); + } else { + modes_m = CoordinateDecompositionLittleEndian(coord_m, params_.divmod_m); + } + + offset_modes_m = dot(modes_m, params_.stride_m); + } + + // + // Compute the offset due to modes M + // + + bool row_guard = (coord_m < extent_row_); + int64_t offset_modes_n = thread_start_column_ * params_.stride_n[0]; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + // + // Compute coordinate and decompose into N modes + // + + if (Layout::kRank > 2) { + offset_modes_n = offset_modes_n_[column]; + } + + // + // Compute the pointer and access + // + bool guard; + if (Layout::kRank > 2) { + guard = row_guard && mask_.predicates[column]; + } else { + guard = (coord_m < extent_row_) && ((thread_start_column_ + ThreadMap::Delta::kColumn * column) < extent_col_); + } + + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset), + guard); + + if (Layout::kRank == 2) { + offset_modes_n += params_.rank2_inc_col; + } + } + + if (Layout::kRank == 2) { + offset_modes_m += params_.rank2_inc_row; + } + } + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + + store_with_byte_offset(frag, 0); + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAffineRankN &operator++() { + + ++state_[0]; + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup 
* + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h new file mode 100644 index 0000000000000000000000000000000000000000..9990dbdbfc1445f91df4aa6a1eb5776663a5d4fd --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_affine_layout_params.h @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. 
+ * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/fast_math.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Rank +> +struct PredicatedTileIteratorAffineLayoutRankNParams { + using Layout = layout::AffineRankN; + using TensorCoord = typename Layout::TensorCoord; + + static bool const kBigEndian = false; + + // + // Data members + // + + Layout layout; + + /// Stride in units of bytes along M modes + Coord stride_m; + + /// Stride in units of bytes along N modes + Coord stride_n; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank/2 - 1)]; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_n[(Layout::kRank == 2) ? 
1 : (Layout::kRank/2 - 1)]; + + int64_t rank2_inc_col; + int64_t rank2_inc_row; + + // + // Methods + // + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAffineLayoutRankNParams() { } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAffineLayoutRankNParams(TensorCoord const &extent, + Layout const &layout_, + int64_t element_sizeof_bits) + : layout(layout_) + { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits); + } + + if (kBigEndian) { + // "Big Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i + 1]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]); + } + } + else { + // "Little Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]); + } + } + + #if 0 + // + // Debug print statements to verify extents and strides are passed correctly. 
+ // + printf("PredicatedTileIteratorAffine::Params() entered\n"); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + printf(" extent[%d]: %d\n", i, extent[i]); + } + for (int i = 0; i < Layout::kRank; ++i) { + printf(" stride[%d]: %ld\n", i, layout_.stride()[i]); + } + printf("PredicatedTileIteratorAffine::Params() returning\n"); + #endif + } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAffineLayoutRankNParams(Layout const &layout_, + int32_t threadmap_delta_kColumn, + int32_t threadmap_delta_kRow, + int64_t element_sizeof_bits) + : layout(layout_) + { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i], element_sizeof_bits); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2], element_sizeof_bits); + } + + rank2_inc_col = threadmap_delta_kColumn * stride_n[0]; + rank2_inc_row = threadmap_delta_kRow * stride_m[0]; + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h new file mode 100644 index 0000000000000000000000000000000000000000..518ad0908c48a7e99b5cdb87792fd4b1a6d2672d --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h @@ -0,0 +1,633 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + BlasMode BlasMode_ = BlasMode::kGemm ///< Tile Iterator for a Symmetric or Hermitian Kernel +> +class PredicatedTileIteratorBlas3 { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static BlasMode const kBlasMode = BlasMode_; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( 
ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + static_assert( AccessType::kElements == 1, "BLAS3 Epilogue must use AccessType::kElements as 1"); + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout): + PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) + { + + } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Fill Mode for a tile on diagonal of a symmetric kernel + cutlass::FillMode fill_mode; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// Internal state counter + int state_[3]; + + /// Starting address of the matrix + size_t matrix_start_addr; + + static_assert((kBlasMode == BlasMode::kSymmetric || kBlasMode == BlasMode::kHermitian), + "Unsupported blas3 mode."); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorBlas3( + PredicatedTileIteratorParams const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset + , cutlass::FillMode fill_mode + ): + params_(params), fill_mode(fill_mode) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + thread_start_row_ = thread_offset.row(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Check Symmetric kernel modes (Lower and Upper - for diagonal CTAs, None for rest CTAs) + if ((kBlasMode == BlasMode::kSymmetric || kBlasMode == BlasMode::kHermitian) && + fill_mode == cutlass::FillMode::kInvalid) { + arch::device_breakpoint(); + } + + // Starting address of the matrix + matrix_start_addr = reinterpret_cast(pointer); + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; 
+ + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment on the diagonal of a symmetric kernel to memory + CUTLASS_DEVICE + void 
load_symmetric_with_byte_offset(Fragment &frag, int64_t byte_offset) { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + bool isLowerMode = (fill_mode == cutlass::FillMode::kLower) ? true : false; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + // Offset of row from beginning of the matrix per thread + size_t row_start_offset = (size_t)memory_pointer - matrix_start_addr; + + // Absolute row index + int row_index = int(row_start_offset/params_.stride); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + // Offset of column from beginning of row per thread + size_t col_start_offset = row_start_offset + + (column * ThreadMap::Delta::kColumn / kElementsPerAccess) * sizeof(AccessType); + + // Absolute column index + size_t col_index = (col_start_offset%params_.stride)/sizeof(AccessType); + guard = guard && ( (isLowerMode && row_index >= col_index) || + (!isLowerMode && row_index <= col_index) ); + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + + // The imaginary parts of the diagonal elements of a 
complex element are assumed and set to zero + if (guard && kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { + Element *scalar_ptr = reinterpret_cast(frag_ptr); + + if (row_index == col_index) { + scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = + real(scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]); + } + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + + if (fill_mode == cutlass::FillMode::kNone) { + load_with_byte_offset(frag, 0); + } + else { + load_symmetric_with_byte_offset(frag, 0); + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + 
cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment on the diagonal of a symmetric kernel to memory + CUTLASS_DEVICE + void store_symmetric_with_byte_offset(Fragment const &frag, int64_t byte_offset) { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + bool isLowerMode = (fill_mode == cutlass::FillMode::kLower) ? true : false; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + // Offset of row from beginning of the matrix per thread + size_t row_start_offset = (size_t)memory_pointer - matrix_start_addr; + + // Absolute row index + int row_index = int(row_start_offset/params_.stride); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + // Offset of column from beginning of row per thread + size_t 
col_start_offset = row_start_offset + + (column * ThreadMap::Delta::kColumn / kElementsPerAccess) * sizeof(AccessType); + + // Absolute column index + size_t col_index = (col_start_offset%params_.stride)/sizeof(AccessType); + + guard = guard && ( (isLowerMode && row_index >= col_index) || + (!isLowerMode && row_index <= col_index) ); + + // The imaginary parts of the diagonal elements of a complex element are assumed and set to zero + if (guard && kBlasMode == BlasMode::kHermitian && cutlass::is_complex::value) { + + AccessType *frag_ptr_modify = const_cast(frag_ptr); + Element *scalar_ptr = reinterpret_cast(frag_ptr_modify); + + if (row_index == col_index) { + scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = + real(scalar_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]); + } + } + + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + + if (fill_mode == cutlass::FillMode::kNone) { + store_with_byte_offset(frag, 0); + } + else { + store_symmetric_with_byte_offset(frag, 0); + } + + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorBlas3 &operator++() { + + ++state_[0]; + byte_pointer_ += params_.advance_row; + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + 
ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..49ee22efad4bb40366b6358dba13ec689a3e059d --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/permute.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIteratorConv | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not + bool UseCUDAStore = false, + int Rank = 4 +> +class PredicatedTileIteratorConv { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + static int const kRank = Rank; + using Layout = typename platform::conditional::type; + + using Stride = typename Layout::Stride; + static int const kStrideRank = Layout::kStrideRank; + + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using MappedLayout = layout::RowMajor; + using Index = typename MappedLayout::Index; + using LongIndex = typename MappedLayout::LongIndex; + using TensorCoord = typename MappedLayout::TensorCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static bool constexpr PermuteD = !layout::is_trivial_permute; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + 
// + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod[kStrideRank - 1]; + Stride tensor_stride; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, cutlass::Tensor4DCoord const &tensor_extent): + PredicatedTileIteratorParams( + layout.stride()[0] * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) { + divmod[0] = FastDivmod(tensor_extent[2] /* Q for Fprop & W for Deconv*/); + divmod[1] = FastDivmod(tensor_extent[1] /* P for Fprop & H for Deconv*/); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStrideRank; ++i) { + tensor_stride[i] = layout.stride()[i]; + } + } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, cutlass::Tensor5DCoord const &tensor_extent): + PredicatedTileIteratorParams( + layout.stride()[0] * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) { + divmod[0] = FastDivmod(tensor_extent[3] /* Q for Fprop & W for Deconv*/); + divmod[1] = FastDivmod(tensor_extent[2] /* P for Fprop & H for Deconv*/); + divmod[2] = FastDivmod(tensor_extent[1] /* Z for Fprop & D for Deconv*/); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStrideRank; ++i) { + tensor_stride[i] = layout.stride()[i]; + } + } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for 
(int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + Params params_; + + /// Byte-level pointer. This pointer is usually for both load() and store(), unless PermuteD is performed. When having PermuteD, byte_pointer_ is only for load(). + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorConv( + Params const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord() + ): + params_(params) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + // 
Initialize byte_pointer_ + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + Stride tensor_coord = CoordinateDecompositionLittleEndian(row_offset + thread_start_row_, params_.divmod); + + LongIndex tensor_offset = dot(tensor_coord, params_.tensor_stride); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess + tensor_offset / kElementsPerAccess], + guard); + } + } + } + } + } + + /// Loads a fragment from 
memory + CUTLASS_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + Stride tensor_coord = CoordinateDecompositionLittleEndian((row_offset + thread_start_row_), params_.divmod); + + LongIndex tensor_offset = dot(tensor_coord, params_.tensor_stride); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[tensor_offset / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)&memory_pointer[tensor_offset / kElementsPerAccess], + guard); + } + + memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess); + } + } + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) const { + + store_with_byte_offset(frag, 0); + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { 
+ return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorConv &operator++() { + + ++state_[0]; + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + + thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow + * ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; + } + } + } + + return *this; + } + + /// Advances a number of positions to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorConv &operator+=(int increment) + { + // Row + state_[0] += increment; + int increment_row = state_[0] / ThreadMap::Count::kRow; + state_[0] = state_[0] % ThreadMap::Count::kRow; + + thread_start_row_ += (ThreadMap::Shape::kRow * increment); + + // Group + state_[1] += increment_row; + int increment_group = state_[1] / ThreadMap::Count::kGroup; + state_[1] = state_[1] % ThreadMap::Count::kGroup; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * + 
ThreadMap::Shape::kRow * + ThreadMap::Count::kRow * + increment_row; + + // Cluster + state_[2] += increment_group; + int increment_cluster = state_[2] / ThreadMap::Count::kCluster; + state_[2] = state_[2] % ThreadMap::Count::kCluster; + + thread_start_row_ += + ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * + ThreadMap::Shape::kRow * + increment_group; + + // Tile + thread_start_row_ += + ThreadMap::Shape::kGroup * + ThreadMap::Shape::kRow * + ThreadMap::Shape::kCluster * + ThreadMap::Shape::kTile * + increment_cluster; + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..0d1f171100d40fa8fd07d643b28c547d817cae56 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h @@ -0,0 +1,445 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/permute.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/conv/conv2d_problem_size.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: PitchLinearThreadMap) + typename Element_, ///< Element data type + typename ThreadOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>, + typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> +> +class PredicatedTileIteratorDirectConv { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + using ThreadOutputShape = ThreadOutputShape_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + + using ConvProblemSize = typename cutlass::conv::Conv2dProblemSize; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + static int const kLoadsPerAccess = AccessType::kElements / AccessType::kElements; + + using ThreadTileCount = MatrixShape< + ThreadBlockOutputShape::kH / ThreadOutputShape::kH, + ThreadBlockOutputShape::kW / ThreadOutputShape::kW + >; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorDirect2dConvParams { + using Base = PredicatedTileIteratorDirect2dConvParams; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, cutlass::conv::Conv2dProblemSize const &problem_size): + PredicatedTileIteratorDirect2dConvParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + problem_size, + {ThreadBlockOutputShape::kH, ThreadBlockOutputShape::kW} + ) + { } + + CUTLASS_HOST_DEVICE + 
Params(Base const &base) : + Base(base) { } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kContiguous; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorDirect2dConvParams params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// + Element *pointer_; + + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Initial thread output location + int thread_start_n_, thread_start_p_, thread_start_q_; + + /// Current threadblock tile index + int tile_index_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorDirect2dConvParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + + + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorDirectConv( + PredicatedTileIteratorDirect2dConvParams const & params, + Element *pointer, + 
TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord() + ): + params_(params), pointer_(pointer) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + // stride dim (PQ) + thread_start_row_ = thread_offset.column(); + // contiguous dim (Channels) + thread_start_column_ = threadblock_offset.column() + thread_offset.row(); + + tile_index_ = threadblock_offset.row(); + + set_tile_index(0); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void set_tile_index(const int index) { + + int residual; + params_.pq_divmod(thread_start_n_, residual, tile_index_ + index); + params_.q_divmod(thread_start_p_, thread_start_q_, residual); + + // Compute the base output coord of ThreadBlock + thread_start_p_ *= ThreadBlockOutputShape::kH; + thread_start_q_ *= ThreadBlockOutputShape::kW; + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + mask_.predicates[c] = ((thread_start_column_ + + c * ThreadMap::Delta::kContiguous) < extent_column_); + } + + // Null pointer performs no accesses + if (!pointer_) { + mask_.clear(); + } + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; + + int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided; + int p = current_row / ThreadBlockOutputShape::kW; + int q = current_row % ThreadBlockOutputShape::kW; + + int 
current_p = thread_start_p_ + p; + int current_q = thread_start_q_ + q; + + bool row_guard = (current_p) < params_.P && (current_q) < params_.Q && + (thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided; + + int output_row_offset = + thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q; + + uint8_t *byte_pointer = + reinterpret_cast(pointer_) + + LongIndex(output_row_offset) * LongIndex(params_.stride) + + LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) * + sizeof(AccessType) / kElementsPerAccess; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + bool guard = row_guard && mask_.predicates[c]; + + cutlass::arch::global_load( + frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard); + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) const { + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; + + int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided; + int p = current_row / ThreadBlockOutputShape::kW; + int q = current_row % ThreadBlockOutputShape::kW; + + int current_p = thread_start_p_ + p; + int current_q = thread_start_q_ + q; + + bool row_guard = (current_p) < params_.P && (current_q) < params_.Q && + (thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided; + + int output_row_offset = + thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q; + + uint8_t *byte_pointer = + reinterpret_cast(pointer_) + + LongIndex(output_row_offset) * LongIndex(params_.stride) + 
+ LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) * + sizeof(AccessType) / kElementsPerAccess; + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + bool guard = row_guard && mask_.predicates[c]; + + cutlass::arch::global_store( + frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard); + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) const { + + store_with_byte_offset(frag, 0); + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorDirectConv &operator++() { + // do nothing + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + 
+//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h new file mode 100644 index 0000000000000000000000000000000000000000..11ec3d72ea14fd23a99ead9a52fe14f947436a1a --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h @@ -0,0 +1,483 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/conv/conv2d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct OutputTileShapeDesc { + + int column; + int row; + int group; + int cluster; + int tile; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + OutputTileShapeDesc(): column(0), row(0), group(0), cluster(0), tile(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + OutputTileShapeDesc( + int column_, + int row_, + int group_, + int cluster_, + int tile_ + ): + column(column_), + row(row_), + group(group_), + cluster(cluster_), + tile(tile_) { } + + /// Total number of points in the 5D space + CUTLASS_HOST_DEVICE + int count() const { + return column * row * group * cluster * tile; + } + + #if 0 + CUTLASS_HOST_DEVICE + void print() const { + printf("{%d, %d, %d, %d, %d}", column, row, group, cluster, tile); + } + #endif +}; + +/// Helper template to construct an OutputTileShapeDesc from a OutputTileShape template. 
+template +CUTLASS_HOST_DEVICE +OutputTileShapeDesc make_OutputTileShapeDesc() { + return OutputTileShapeDesc( + Shape::kColumn, + Shape::kRow, + Shape::kGroup, + Shape::kCluster, + Shape::kTile + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread map description +struct OutputTileThreadMapDesc { + + int threads; + int elements_per_access; + OutputTileShapeDesc shape; + OutputTileShapeDesc iterations; + OutputTileShapeDesc delta; + OutputTileShapeDesc count; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + OutputTileThreadMapDesc() { } + + CUTLASS_HOST_DEVICE + OutputTileThreadMapDesc( + int threads_, + int elements_per_access_, + OutputTileShapeDesc shape_, + OutputTileShapeDesc iterations_, + OutputTileShapeDesc delta_, + OutputTileShapeDesc count_ + ): + threads(threads_), + elements_per_access(elements_per_access_), + shape(shape_), + iterations(iterations_), + delta(delta_), + count(count_) + { + + } +}; + +/// Helper template to construct an OutputTileShapeDesc from a OutputTileThreadMap template. 
+template +CUTLASS_HOST_DEVICE +OutputTileThreadMapDesc make_OutputTileThreadMapDesc() { + return OutputTileThreadMapDesc( + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + make_OutputTileShapeDesc(), + make_OutputTileShapeDesc(), + make_OutputTileShapeDesc(), + make_OutputTileShapeDesc() + ); +} +/////////////////////////////////////////////////////////////////////////////// + +// +// Parameters struct for PredicatedTileIterator +// + +struct PredicatedTileIteratorParams { + + using Index = int32_t; + using LongIndex = int64_t; + + // + // Data members + // + + LongIndex stride; ///< stride in bytes between rows + + LongIndex increment_row; ///< increment quantity (in bytes) to advance when moving between rows + LongIndex increment_group; ///< increment quantity (in bytes) to advance when moving to the next group + LongIndex increment_cluster; ///< increment quantity (in bytes) to advance when moving to the next cluster + + LongIndex advance_row; ///< amount to add to move to the next 'row' position + LongIndex advance_group; ///< amount to add to move to the next 'group' position + LongIndex advance_cluster; ///< amount to add to move to the next 'cluster' position + LongIndex advance_tile; ///< amount to add to move to the next 'tile' + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Status initialize(LongIndex stride_, OutputTileThreadMapDesc thread_map) { + + stride = stride_; + + increment_row = stride * thread_map.delta.row; + + increment_group = stride * thread_map.delta.group + - stride * thread_map.delta.row * (thread_map.iterations.row - 1); + + increment_cluster = stride * thread_map.delta.cluster + - stride * thread_map.delta.group * (thread_map.iterations.group - 1) + - stride * thread_map.delta.row * (thread_map.iterations.row - 1); + + advance_row = stride * thread_map.shape.row; + + advance_group = + stride * + (thread_map.shape.group - 1) * thread_map.shape.row * thread_map.count.row; + + advance_cluster = + stride * + 
thread_map.count.group * + thread_map.shape.group * + thread_map.count.row * + thread_map.shape.row; + + advance_tile = + stride * + thread_map.shape.group * + thread_map.shape.row * + thread_map.shape.cluster * + thread_map.shape.tile; + + return Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Status initialize(Index stride_, OutputTileThreadMapDesc thread_map) { + return initialize(LongIndex(stride_), thread_map); + } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorParams() { + initialize(LongIndex(0), OutputTileThreadMapDesc()); + } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorParams(Index stride, OutputTileThreadMapDesc thread_map) { + initialize(stride, thread_map); + } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorParams(LongIndex stride, OutputTileThreadMapDesc thread_map) { + initialize(stride, thread_map); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +// +// Parameters struct for PredicatedTileIteratorDirect2dConv +// + +struct PredicatedTileIteratorDirect2dConvParams{ + using Index = int32_t; + using LongIndex = int64_t; + + // + // Data members + // + FastDivmod pq_divmod; + FastDivmod q_divmod; + + LongIndex stride; + LongIndex stride_n; + LongIndex stride_p; + + int N; + int P; + int Q; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Status initialize(LongIndex stride_, + cutlass::conv::Conv2dProblemSize const &problem_size, + MatrixCoord threadblock_output_shape) { + stride = stride_; // The stride per row of output tensor (bytes) + stride_n = problem_size.P * problem_size.Q; + stride_p = problem_size.Q ; + + N = problem_size.N; + P = problem_size.P; + Q = problem_size.Q; + + // Fastdivmod for output O, P, Q + if(threadblock_output_shape.row() != 0 && threadblock_output_shape.column() !=0 ){ + // MSVC emits a "potential divide by 0" warning as error + // if the code just divides without a check and substitution. 
+ + CUTLASS_ASSERT(threadblock_output_shape.row() != 0); + const auto row_denom = threadblock_output_shape.row() != 0 ? + threadblock_output_shape.row() : cutlass::MatrixCoord::Index(1); + int tiles_p = + (problem_size.P + (threadblock_output_shape.row() - 1)) / row_denom; + + CUTLASS_ASSERT(threadblock_output_shape.column() != 0); + const auto col_denom = threadblock_output_shape.column() != 0 ? + threadblock_output_shape.column() : cutlass::MatrixCoord::Index(1); + int tiles_q = (problem_size.Q + (threadblock_output_shape.column() - 1)) / + col_denom; + + pq_divmod = FastDivmod(tiles_p * tiles_q); + q_divmod = FastDivmod(tiles_q); + } + + return Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Status initialize( + Index stride_, + cutlass::conv::Conv2dProblemSize const &problem_size = cutlass::conv::Conv2dProblemSize(), + MatrixCoord threadblock_output_shape = MatrixCoord()) { + return initialize(LongIndex(stride_), problem_size, threadblock_output_shape); + } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorDirect2dConvParams() { initialize(LongIndex(0)); } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorDirect2dConvParams(Index stride, + cutlass::conv::Conv2dProblemSize const &problem_size, + MatrixCoord threadblock_output_shape) { + initialize(stride, problem_size, threadblock_output_shape); + } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorDirect2dConvParams(LongIndex stride, + cutlass::conv::Conv2dProblemSize const &problem_size, + MatrixCoord threadblock_output_shape) { + initialize(stride, problem_size, threadblock_output_shape); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// InterleavedPredicatedTileIterator +/////////////////////////////////////////////////////////////////////////////// + + +/// Predicated tile access iterator descriptor object containing template dependent state +struct InterleavedPredicatedTileIteratorDesc { + + int element_size_bits; + int elements_per_access; + int threadmap_warp_size; 
+ layout::PitchLinearCoord threadmap_iterations; + layout::PitchLinearCoord threadmap_delta; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorDesc() { } + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorDesc( + int element_size_bits_, + int elements_per_access_, + int threadmap_warp_size_, + layout::PitchLinearCoord threadmap_iterations_, + layout::PitchLinearCoord threadmap_delta_ + ): + element_size_bits(element_size_bits_), + elements_per_access(elements_per_access_), + threadmap_warp_size(threadmap_warp_size_), + threadmap_iterations(threadmap_iterations_), + threadmap_delta(threadmap_delta_) { } +}; + +// +// Parameters struct InterleavedPredicatedTileIterator +// + +struct InterleavedPredicatedTileIteratorParams { + + using Index = int32_t; + using LongIndex = int64_t; + + // + // Data members + // + + LongIndex stride; ///< stride in bytes between rows + LongIndex advance_row; ///< amount to add to move to the next 'row' position + LongIndex advance_column; ///< amount to add to move to the next 'column' position + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Status initialize(LongIndex stride_, InterleavedPredicatedTileIteratorDesc desc) { + + stride = stride_; + + advance_row = desc.threadmap_delta.contiguous() * desc.element_size_bits / 8; + + advance_column = stride_ - desc.threadmap_iterations.contiguous() * + desc.elements_per_access * + desc.element_size_bits * + desc.threadmap_warp_size / 8; + + return Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorParams() { + initialize(LongIndex(0), InterleavedPredicatedTileIteratorDesc()); + } + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorParams(Index stride, InterleavedPredicatedTileIteratorDesc desc) { + initialize(stride, desc); + } + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorParams(LongIndex stride, InterleavedPredicatedTileIteratorDesc desc) { + initialize(stride, desc); + } +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper template to construct an OutputTileShapeDesc from a OutputTileThreadMap template. +template +CUTLASS_HOST_DEVICE +InterleavedPredicatedTileIteratorDesc make_InterleavedPredicatedTileIteratorDesc() { + return InterleavedPredicatedTileIteratorDesc( + sizeof_bits::value, + ThreadMap::kElementsPerAccess, + ThreadMap::kWarpSize, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper template to construct an MakePredicatedTileIteratorDesc from a template +// dependent state +template + struct MakePredicatedTileIteratorDesc; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for layout::RowMajor output data. +template +struct MakePredicatedTileIteratorDesc < + Element, layout::RowMajor, ThreadMap> { + + CUTLASS_HOST_DEVICE + OutputTileThreadMapDesc operator()() { + + return make_OutputTileThreadMapDesc(); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for layout::ColumnMajorInterleaved output data. 
+template +struct MakePredicatedTileIteratorDesc < + Element, layout::ColumnMajorInterleaved, ThreadMap> { + + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIteratorDesc operator()() { + + return make_InterleavedPredicatedTileIteratorDesc(); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h new file mode 100644 index 0000000000000000000000000000000000000000..a4ed371f4d9d22f205306fb43253f5021168b003 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_predicates.h @@ -0,0 +1,309 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief PredicatedTileIteratorPredicates. + + PredicatedTileIteratorPredicates enables both upper and lower bounds for predicates. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator predicates used to bound computations in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_ ///< Element data type +> +class PredicatedTileIteratorPredicates { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 
0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout): + PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) + { + + } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ PredicatedTileIteratorParams params_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index lower_extent_row_; + Index upper_extent_row_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// Internal state counter + int state_[3]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(lower_extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(upper_extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPredicates( + PredicatedTileIteratorParams const & params, + TensorCoord lower_extent, + TensorCoord upper_extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord() + ): + params_(params) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + lower_extent_row_ = lower_extent.row(); + upper_extent_row_ = upper_extent.row(); + thread_start_row_ = thread_offset.row(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < upper_extent.column()) && + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) >= lower_extent.column()); + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPredicates &operator++() { + + ++state_[0]; + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + 
+ thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Gets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } + + ///< Gets lower_extent_row_ + CUTLASS_DEVICE Index get_lower_extent_row() { + return lower_extent_row_; + } + + ///< Gets upper_extent_row_ + CUTLASS_DEVICE Index get_upper_extent_row() { + return upper_extent_row_; + } + + ///< Gets thread_start_row_ + CUTLASS_DEVICE Index get_thread_start_row() { + return thread_start_row_; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h new file mode 100644 index 0000000000000000000000000000000000000000..dfe9571e72bafe38b8877d64106e3dda6c0d93d3 --- /dev/null +++ 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h @@ -0,0 +1,479 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_ ///< Element data type +> +class PredicatedTileIteratorStridedDgrad { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + + /// Convolution problem size + cutlass::conv::Conv2dProblemSize problem_size; + int tiled_rows_per_filter; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, cutlass::conv::Conv2dProblemSize problem_size_, int threadblock_row): + problem_size(problem_size_), + PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) 
/ kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) + { + + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_row); + + tiled_rows_per_filter = tile_m_per_filter * threadblock_row; + } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + Params params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Starting Dx h and w dimension for strided dgrad mapping + int start_h_, start_w_; + + /// Effective Dy P and Q dimensions for strided dgrad mapping + int p_, q_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column position (assuming steady-state predicates have been computed) + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + 
/// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorStridedDgrad( + Params const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, + int start_r, int start_s, + TensorCoord threadblock_offset = TensorCoord() + ): + params_(params) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + int r = start_r; + int s = start_s; + + if (params_.problem_size.mode == cutlass::conv::Mode::kConvolution) { + r = (params_.problem_size.R - 1 - r); + s = (params_.problem_size.S - 1 - s); + } + + // compute starting coordinates in Dx start_h_ and start_w_ + strided_dgrad_starting_coords( + params_.problem_size, + stride_h_divmod, stride_w_divmod, + r, s, + start_h_, start_w_); + + p_ = (params_.problem_size.H - start_h_ + params_.problem_size.stride_h - 1) / params_.problem_size.stride_h; + q_ = (params_.problem_size.W - start_w_ + params_.problem_size.stride_w - 1) / params_.problem_size.stride_w; + + extent_row_ = extent.row(); + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer); + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr 
= reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + // remapping rows to find the mapped_row_offset + int npq_offset = (row_offset + thread_start_row_) % params_.tiled_rows_per_filter; + + // (STEP 4.a) [order NHW rows to be loaded and stored in output Dx NHWxC layout] + int n = npq_offset / (p_ * q_); + int residual = npq_offset % (p_ * q_); + int p = residual / q_; + int q = residual % q_; + + int mapped_row_offset = n * (params_.problem_size.H * params_.problem_size.W) + + (start_h_ + p * params_.problem_size.stride_h) * params_.problem_size.W + + (start_w_ + q * params_.problem_size.stride_w); + bool row_guard = mapped_row_offset < extent_row_; + + int64_t row_byte_offset = mapped_row_offset * params_.stride; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int64_t column_byte_offset = (thread_start_column_ + column * ThreadMap::Delta::kColumn) * (sizeof_bits::value / 8); + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)(byte_pointer + row_byte_offset + column_byte_offset + byte_offset), + guard); + } + } + } + } + } + + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, 
int64_t byte_offset) { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + // remapping rows to find the mapped_row_offset + int npq_offset = (row_offset + thread_start_row_) % params_.tiled_rows_per_filter; + + // (STEP 4.a) [order NHW rows to be loaded and stored in output Dx NHWxC layout] + int n = npq_offset / (p_ * q_); + int residual = npq_offset % (p_ * q_); + int p = residual / q_; + int q = residual % q_; + + int mapped_row_offset = n * (params_.problem_size.H * params_.problem_size.W) + + (start_h_ + p * params_.problem_size.stride_h) * params_.problem_size.W + + (start_w_ + q * params_.problem_size.stride_w); + bool row_guard = mapped_row_offset < extent_row_; + + int64_t row_byte_offset = mapped_row_offset * params_.stride; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int64_t column_byte_offset = (thread_start_column_ + column * ThreadMap::Delta::kColumn) * (sizeof_bits::value / 8); + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)(byte_pointer + row_byte_offset + column_byte_offset + byte_offset), + guard); + } + } + } + } + } + + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + + store_with_byte_offset(frag, 0); + } + + /// Advances to the next position to load 
or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorStridedDgrad &operator++() { + + ++state_[0]; + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..a321f1b61b3364d2c6450604b14822a2dc560a26 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator.h @@ -0,0 +1,223 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 
2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. +/// +/// Satisfies: ReadableTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + int MaxAlignment = ThreadMap_::kElementsPerAccess * sizeof_bits::value / 8 +> +class SharedLoadIterator { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::TileShape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kMinAlignment = ThreadMap_::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kAlignment = (MaxAlignment < kMinAlignment ? 
MaxAlignment : kMinAlignment); + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray< + Element, + ThreadMap::kElementsPerAccess, + kAlignment>; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray< + Element, + const_min(128 / sizeof_bits::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment) + >; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + +private: + + // + // Data members + // + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Stride along adjacent rows + int stride_; + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIterator( + TensorRef ref, + int thread_idx + ): + byte_pointer_(reinterpret_cast(ref.data())), + stride_((ref.stride(0) * sizeof_bits::value) / 8) { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointer + byte_pointer_ += + thread_offset.row() * stride_ + + thread_offset.column() * sizeof(AccessType) / kElementsPerAccess; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &offset) { + byte_pointer_ += + offset.row() * Shape::kRow * stride_ + + offset.column() * Shape::kColumn * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < 
ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + uint8_t const *byte_pointer = byte_pointer_ + + row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup* stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset * sizeof_bits::value / 8; + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType *frag_ptr = reinterpret_cast(&frag); + LoadType const *memory_pointer = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + frag_ptr[frag_idx * kLoadsPerAccess + v] = + memory_pointer[(column * ThreadMap::Delta::kColumn / kElementsPerAccess) * kLoadsPerAccess + v]; + } + } + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void set_smem_base_address(Index address) { + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment &frag) const { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h new file mode 100644 index 0000000000000000000000000000000000000000..66cc17f72817d1feb4d8eb6c6242c1e8efb5ce2e --- /dev/null +++ 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h @@ -0,0 +1,594 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops optimized for mixed-precision. + + This assumes the shared memory tile is in a permuted layout which avoids bank conflicts on loading. + + When the fragment is loaded into registers, it matches the row-major thread map assumed by + the predicated tile iterator writing to global memory. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Accumulator data type + int ElementSizeBits_, ///< Size of accumulator in bits + int OutputSizeBits_, ///< Size of output element in bits + int ElementsPerAccess, ///< Vector length of output vector + int ContiguousLanes, ///< Number of lanes in the warp writing to contiguous elements + /// in the global memory tensor + bool EightBitsOutputOrLess = (OutputSizeBits_ <= 8) +> +class SharedLoadIteratorMixed; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. +/// +/// Satisfies: ReadableTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_ ///< Accumulator data type +> +class SharedLoadIteratorMixed { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray< + Element, + ThreadMap::kElementsPerAccess, + kAlignment>; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray< + Element, + const_min(128 / 
sizeof_bits::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment) + >; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + +private: + + // + // Data members + // + + /// Byte-level pointer + LoadType const *pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed( + TensorRef ref, + int thread_idx + ): + stride_((ref.stride(0) / LoadType::kElements)) { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += + offset.row() * Shape::kRow * stride_ + + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + 
int row_ptr_offset = + row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup* stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + + int vector_idx = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); + + LoadType const *memory_pointer = pointers_[v] + row_ptr_offset; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; + } + } + } + } + } + } + + /// Set base smem address + CUTLASS_DEVICE + void set_smem_base_address(Index address) {} + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment &frag) const { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for +/// int32_t x 16 => int8_t/int4b_t x 16 and +/// float x 16 => float_e4m3_t/float_e5m2_t x 16 +template < + typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap) + typename Element_, + int OutputSizeBits_ ///< Size of output element in bits +> +class SharedLoadIteratorMixed { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + static_assert(sizeof_bits::value == 32, "Element size in bits must be 32."); + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const 
kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = 16; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray< + Element, + 16, + kAlignment>; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray< + Element, + 4, + 16 + >; + + static int const kLoadsPerAccess = 4; + +private: + + // + // Data members + // + + /// Byte-level pointer + LoadType const *pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed( + TensorRef ref, + int thread_idx + ): + stride_((ref.stride(0) / LoadType::kElements)) { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + LoadType const *base_ptr = reinterpret_cast(ref.data()) + thread_offset.row() * stride_; + + int lane_col_idx = thread_offset.column() / 16; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + int lane_offset = (lane_col_idx % 2) * 4 | ((lane_col_idx / 2) * 8) | ((lane_col_idx / 2) ^ i); + + pointers_[i] = base_ptr + lane_offset; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += + offset.row() * Shape::kRow * stride_ + + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment 
from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int row_ptr_offset = + row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup* stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + + LoadType const *memory_pointer = pointers_[v]; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[row_ptr_offset]; + } + } + } + } + } + } + + /// Set base smem address + CUTLASS_DEVICE + void set_smem_base_address(Index address) {} + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment &frag) { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for: +/// int32_t x 8 => int8_t/int4b_t x 8 and +/// float x 8 => float_e4m3_t/float_e5m2_t x 8 +template < + typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap) + typename Element_, + int OutputSizeBits_ +> +class SharedLoadIteratorMixed { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + static_assert(sizeof_bits::value == 32, "Element size in bits must be 32."); + + using Layout = layout::RowMajor; + 
using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray< + Element, + 8, + kAlignment>; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray< + Element, + 4, + 16 + >; + + static int const kLoadsPerAccess = 2; + +private: + + // + // Data members + // + + /// Byte-level pointer + LoadType const *pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed( + TensorRef ref, + int thread_idx + ): + stride_((ref.stride(0) / LoadType::kElements)) { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + LoadType const *base_ptr = reinterpret_cast(ref.data()) + thread_offset.row() * stride_; + + int lane_col_idx = thread_offset.column() / 8; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + int lane_offset = (lane_col_idx % 8) * 2 | ((lane_col_idx / 4) ^ i); + + pointers_[i] = base_ptr + lane_offset; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &offset) { + 
CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += + offset.row() * Shape::kRow * stride_ + + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int row_ptr_offset = + row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup* stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + + LoadType const *memory_pointer = pointers_[v]; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[row_ptr_offset]; + } + } + } + } + } + } + + /// Set base smem address + CUTLASS_DEVICE + void set_smem_base_address(Index address) {} + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment &frag) { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h new file mode 100644 index 0000000000000000000000000000000000000000..74d040ba0be731c2a3faa46a1a4034ed9eccb9e2 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + This assumes the shared memory tile is in a permuted layout which avoids bank conflicts on loading. + + When the fragment is loaded into registers, it matches the row-major thread map assumed by + the predicated tile iterator writing to global memory. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator +/// +template ::value / 8> +class SharedLoadIteratorPitchLinear { + public: + using ThreadMap = ThreadMap_; + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kMinAlignment = + ThreadMap_::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kAlignment = (MaxAlignment < kMinAlignment ? MaxAlignment : kMinAlignment); + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = + AlignedArray::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + + private: + // + // Data members + // + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Stride along adjacent rows + int stride_; + + /// Base address offset + Index base_smem_address_; + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorPitchLinear(TensorRef ref, int thread_idx) + : byte_pointer_(reinterpret_cast(ref.data())), + stride_((ref.stride(0) * sizeof_bits::value) / 8), + base_smem_address_(0) { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointer + // thread_offset.row() is contiguous dim + // thread_offset.column() is stride dim + byte_pointer_ += thread_offset.row() * sizeof(AccessType) / kElementsPerAccess+ + thread_offset.column() * stride_ ; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ 
+= pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &offset) { + byte_pointer_ += + offset.row() * ThreadMap::StorageShape::kContiguous * sizeof(AccessType) / kElementsPerAccess + + offset.column() * ThreadMap::StorageShape::kStrided * stride_; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + uint8_t const *byte_pointer = + byte_pointer_ + s * ThreadMap::Delta::kStrided * stride_ + + c * ThreadMap::Delta::kContiguous * ThreadMap::kElementsPerAccess * + sizeof_bits::value / 8 + + pointer_offset * sizeof_bits::value / 8 + base_smem_address_; + + int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; + + LoadType *frag_ptr = reinterpret_cast(&frag); + + LoadType const *memory_pointer = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + frag_ptr[frag_base_idx * kLoadsPerAccess + v] = memory_pointer[v]; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void set_smem_base_address(Index address) { base_smem_address_ = address; } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..58ccbfacf504b28da2282dc69214b149acda3c65 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h @@ -0,0 +1,187 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/epilogue/warp/tensor_op_policy.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) + typename Layout ///< target shared memory layout +> +class FragmentIteratorComplexTensorOp; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename OperatorShape_, ///< underlying real-valued matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC_, ///< underlying real-valued matrix multiply operation data type + typename OperatorFragmentC_ ///< underlying real-valued matrix multiply operation fragment (concept: Array) +> +class FragmentIteratorComplexTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::RowMajor; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. 
+ using Fragment = Array< + complex, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + static int const kRealIndex = 0; + + /// Offset into the accumulator fragment + static int const kImaginaryIndex = + OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = Array; + + /// This is the complete warp-level accumulator tile. + using OutputAccumulatorTile = Array, kImaginaryIndex>; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + +private: + + /// Internal access type + using AccessType = Array; + + using FragmentAccessType = Array, Policy::kElementsPerAccess>; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorComplexTensorOp(AccumulatorTile const &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + + } + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorComplexTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorComplexTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + FragmentAccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + auto const & real_accum_array = accumulators_[accumulator_access_offset + kRealIndex]; + auto const & imag_accum_array = accumulators_[accumulator_access_offset + kImaginaryIndex / 
Policy::kElementsPerAccess]; + + // Pack real and imaginary parts into a structure. This is likely to result in MOVs + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Policy::kElementsPerAccess; ++i) { + + frag_ptr[n][i].real() = real_accum_array[i]; + frag_ptr[n][i].imag() = imag_accum_array[i]; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b03cab835c7f137db1f923cf393007fbfaa7ed1e --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/epilogue/warp/tensor_op_policy.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) + typename Layout ///< target shared memory layout +> +class FragmentIteratorGaussianComplexTensorOp; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename OperatorShape_, ///< underlying real-valued matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC_, ///< underlying real-valued matrix multiply operation data type + typename OperatorFragmentC_ ///< underlying real-valued matrix multiply operation fragment (concept: Array) +> +class FragmentIteratorGaussianComplexTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::RowMajor; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. 
+ using Fragment = Array< + complex, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// Size of one part of accumulator of 3-part accumulator in units of number of OperatorElementC + static int const kElementsAccumulatorPerPart = + OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn; + + /// Offset into the accumulator fragment part 1 + static int const kPart1Index = kElementsAccumulatorPerPart * 0; + + /// Offset into the accumulator fragment part 2 + static int const kPart2Index = kElementsAccumulatorPerPart * 1; + + /// Offset into the accumulator fragment part 3 + static int const kPart3Index = kElementsAccumulatorPerPart * 2; + + /// This is the complete warp-level accumulator tile holding part1, part2, and part3 + using AccumulatorTile = Array; + + /// This is the complete warp-level accumulator tile holding final output of complex type + using OutputAccumulatorTile = Array, kElementsAccumulatorPerPart>; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + +private: + + /// Internal access type + using AccessType = Array; + + using FragmentAccessType = Array, Policy::kElementsPerAccess>; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorGaussianComplexTensorOp(AccumulatorTile const &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorGaussianComplexTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorGaussianComplexTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + int index 
= index_ + index_offset; + + FragmentAccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + auto const & part1_accum_array = accumulators_[accumulator_access_offset + kPart1Index]; + auto const & part2_accum_array = accumulators_[accumulator_access_offset + kPart2Index / Policy::kElementsPerAccess]; + auto const & part3_accum_array = accumulators_[accumulator_access_offset + kPart3Index / Policy::kElementsPerAccess]; + + // Pack parts 1, 2, and 3 into a structure. This is likely to result in MOVs + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Policy::kElementsPerAccess; ++i) { + + frag_ptr[n][i].real() = part1_accum_array[i] - part3_accum_array[i]; + frag_ptr[n][i].imag() = part1_accum_array[i] + part2_accum_array[i]; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h new file mode 100644 index 0000000000000000000000000000000000000000..404be79f3ba894a90fbd3b6fa8ec56ac1717ff4b --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_simt.h @@ -0,0 +1,164 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. 
Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/epilogue/warp/simt_policy.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fragment iterator for SIMT accumulator arrangements +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename Operator, ///< matrix multiply operation (concept: arch::Mma) + typename Layout, ///< target shared memory layout + typename MmaSimtPolicy ///< policy defining lane arrangement (concept: MmaSimtPolicy) +> +class FragmentIteratorSimt; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename Operator_ , ///< matrix multiply operator (concept: arch::Mma) + typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) +> +class FragmentIteratorSimt { +public: + + using WarpShape = WarpShape_; + using Operator = Operator_; + using Layout = layout::RowMajor; + + /// Policy for warp-level epilogue components + using Policy = SimtPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + typename Operator::ElementC, + Policy::kElementsPerIteration>; + + /// This is the complete warp-level accumulator tile. 
+ using AccumulatorTile = Array< + typename Operator::ElementC, + Policy::kAccumulatorElementCount>; + + using OutputAccumulatorTile = AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + +private: + + /// Internal access type + using AccessType = Array; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorSimt(AccumulatorTile const &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + + } + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorSimt &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorSimt &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + + int accumulator_access_offset = index_ * Policy::kAccessesPerIteration + n; + + frag_ptr[n] = accumulators_[accumulator_access_offset]; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4c6f10b0e694bcc142b60d39e242d9192482d566 --- /dev/null +++ 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h @@ -0,0 +1,378 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/epilogue/warp/tensor_op_policy.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) + typename Layout ///< target shared memory layout +> +class FragmentIteratorTensorOp; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array) +> +class 
FragmentIteratorTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::RowMajor; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + OperatorElementC, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = Array< + OperatorElementC, + OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; + + using OutputAccumulatorTile = AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + using TileIterations = typename Policy::TileIterations; + static int const kIterationsPerTile = kIterations / TileIterations::kCount; + +private: + + /// Internal access type + using AccessType = Array; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp(AccumulatorTile const &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * 
Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + frag_ptr[n] = accumulators_[accumulator_access_offset]; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for col-major shared memory +/// Only works for 168x tensor core kernels +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array) +> +class FragmentIteratorTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::ColumnMajor; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + OperatorElementC, + 4 * Policy::OperatorCount::kRow * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. 
+ using AccumulatorTile = Array< + OperatorElementC, + OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; + + using OutputAccumulatorTile = AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + using TileIterations = typename Policy::TileIterations; + static int const kIterationsPerTile = kIterations / TileIterations::kCount; + +private: + + /// Internal access type + using AccessType = Array; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp(AccumulatorTile const &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Policy::kAccumulatorRowStride; ++i) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < (Policy::OperatorCount::kRow * 2); ++m) { + + int accumulator_access_offset = + index * Policy::kAccumulatorColumnStride + m * Policy::kAccumulatorRowStride / Policy::kElementsPerAccess + i; + + frag_ptr[m + i * Policy::OperatorCount::kRow * 2] = accumulators_[accumulator_access_offset]; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Dedicated to interleaved layout +template < + /// shape of the warp-level GEMM tile + typename WarpShape_, + /// matrix multiply operator shape (concept: 
gemm::GemmShape) + typename OperatorShape_, + /// matrix multiply operator data type (concept: data type) + typename OperatorElementC_, + /// matrix multiply operator fragment (concept: Array) + typename OperatorFragmentC_, + /// number of interleaved k + int InterleavedK> +class FragmentIteratorTensorOp> { + public: + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = + Array; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = + Array; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + using TileIterations = typename Policy::TileIterations; + static int const kIterationsPerTile = kIterations / TileIterations::kCount; + + private: + /// Internal access type + using AccessType = + Array; + + private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp(AccumulatorTile const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0) {} + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < (InterleavedK / 
OperatorShape::kN); ++n) { + int index_m = index % (Policy::OperatorCount::kRow * + Policy::kIterationsPerInstruction); + int index_n = index / (Policy::OperatorCount::kRow * + Policy::kIterationsPerInstruction); + int accumulator_access_offset = + (index_m / Policy::kIterationsPerInstruction) * + (Policy::OperatorCount::kColumn * + Policy::kIterationsPerInstruction) + + (index_m % Policy::kIterationsPerInstruction) + + index_n * (InterleavedK / OperatorShape::kN) * + Policy::kIterationsPerInstruction + + n * Policy::kIterationsPerInstruction; + + frag_ptr[n] = accumulators_[accumulator_access_offset]; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..fede55860c5aa6a24dce06f1b065d2711eef49a4 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_volta_tensor_op.h @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/warp/volta_tensor_op_policy.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename InterleavedTileShape, ///< shape of indivisible instruction-level arrangement (concept: GemmShape) + typename ElementC, ///< Accumulator layout + typename Layout ///< target shared memory layout +> +class FragmentIteratorVoltaTensorOp; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) +> +class FragmentIteratorVoltaTensorOp, half_t, layout::RowMajor> { +public: + + using WarpShape = WarpShape_; + using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; + using ElementC = half_t; + using Layout = layout::RowMajor; + + /// Policy operator + using Policy = VoltaTensorOpPolicy; + + /// Array type for aligned memory accesses + using AccessType = typename Policy::AccessType; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = typename Policy::Fragment; + + /// This is the complete warp-level accumulator tile. 
+ using AccumulatorTile = typename Policy::AccumulatorTile; + + using OutputAccumulatorTile = AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + +private: + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + + } + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorVoltaTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorVoltaTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + static int const kAccessesPerMma = Policy::kElementsPerMma / Policy::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + + int tile_access_idx = + (tile_n * Policy::TileIterations::kRow + (index_ & 2) / 2) * Policy::MmaIterations::kCount * kAccessesPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn * kAccessesPerMma; ++mma_n) { + + int mma_access_idx = ((mma_n & 1) * 2 + (index_ & 1)) * kAccessesPerMma + (mma_n & 2) / 2; + + frag_ptr[tile_n * Policy::MmaIterations::kColumn * kAccessesPerMma + + mma_n] = accumulators_[tile_access_idx + mma_access_idx]; + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) +> +class FragmentIteratorVoltaTensorOp, float, 
layout::RowMajor> { +public: + + using WarpShape = WarpShape_; + using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; + using ElementC = float; + using Layout = layout::RowMajor; + + /// Policy operator + using Policy = VoltaTensorOpPolicy; + + /// Array type for aligned memory accesses + using AccessType = typename Policy::AccessType; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = typename Policy::Fragment; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = typename Policy::AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + +private: + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorVoltaTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorVoltaTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + int const kRegsPerMmaRow = 2; + + CUTLASS_PRAGMA_UNROLL + for (int reg_row = 0; reg_row < Policy::kRowsPerMmaTile; ++reg_row) { + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn * 2; ++mma_n) { + + int mma_idx = (index_ & 1) + (index_ & 2) * Policy::MmaIterations::kCount / 2 + + (tile_n * Policy::TileIterations::kRow) * Policy::MmaIterations::kCount + (mma_n & 1) * 2; + + 
int reg_offset = reg_row * kRegsPerMmaRow + (mma_n & 2) * 2; + int reg_idx = mma_idx * Policy::kElementsPerMma + reg_offset; + + *frag_ptr = accumulators_[reg_idx / Policy::kElementsPerAccess]; + ++frag_ptr; + } + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..245499b02e2758be4d0a8998650a94cffa92112e --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. 
+*/ + +#pragma once + +#include "cutlass/wmma_array.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/epilogue/warp/wmma_tensor_op_policy.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: nvcuda::cuda::fragment) + typename Layout ///< target shared memory layout +> +class FragmentIteratorWmmaTensorOp; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: nvcuda::cuda::fragment) +> +class FragmentIteratorWmmaTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::RowMajor; + + using Policy = WmmaTensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = WmmaFragmentArray; + + /// This is the complete warp-level accumulator tile. 
+ using AccumulatorTile = WmmaFragmentArray; + + using OutputAccumulatorTile = AccumulatorTile; + +private: + + /// Internal access type + using AccessType = WmmaFragmentArray; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorWmmaTensorOp(AccumulatorTile const &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorWmmaTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorWmmaTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for(int n=0; n < Policy::OperatorCount::kColumn; n++) { + + int accumulator_access_offset = index_ * Policy::OperatorCount::kColumn + n; + + frag_ptr[n] = accumulators_[accumulator_access_offset]; + } + } +}; + + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/simt_policy.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/simt_policy.h new file mode 100644 index 0000000000000000000000000000000000000000..a1fa65ca57aa2599c4321202a9ee9dca5ffef3a6 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/simt_policy.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. + These quantities assume a 'column-major' arrangement of SimtOp instructions, of which + a row-oriented slice is visible per iteration. 
+*/ + +#pragma once + +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: GemmShape) + typename Operator, ///< matrix multiply operation (concept: arch::Mma) + typename Layout, ///< destination layout in shared memory + typename MmaSimtPolicy ///< policy defining lane arrangement (concept: MmaSimtPolicy) +> +struct SimtPolicy; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: MatrixShape) + typename Operator_, ///< matrix multiply operation (concept: arch::Mma) + typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) +> +struct SimtPolicy { + + using WarpShape = WarpShape_; + using Operator = Operator_; + using MmaSimtPolicy = MmaSimtPolicy_; + + static_assert(!(WarpShape::kM % MmaSimtPolicy::WarpShape::kRow), "Divisibility"); + static_assert(!(WarpShape::kN % MmaSimtPolicy::WarpShape::kColumn), "Divisibility"); + + /// Number of iterations + static int const kIterations = WarpShape::kM / MmaSimtPolicy::WarpShape::kRow; + + /// Number of accumulators written per iteration + static int const kElementsPerIteration = + (WarpShape::kN / MmaSimtPolicy::WarpShape::kColumn); + + /// Total number of accumulators + static int const kAccumulatorElementCount = kElementsPerIteration * kIterations; + + /// Number of consecutive elements + static int const kElementsPerAccess = MmaSimtPolicy::LaneMmaShape::kN; + + /// Number of rows per epilogue iteration + static int const kRowsPerIteration = MmaSimtPolicy::WarpShape::kRow; + + 
/// Number of accesses made in one iteration + static int const kAccessesPerIteration = kElementsPerIteration / kElementsPerAccess; + + /// Number of elements in between accumulator chunks of (LaneMmaShape::kM x LaneMmaShape::kN) + using Delta = MatrixShape< + MmaSimtPolicy::WarpShape::kRow * MmaSimtPolicy::LaneMmaShape::kM, + MmaSimtPolicy::WarpShape::kColumn * MmaSimtPolicy::LaneMmaShape::kN + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h new file mode 100644 index 0000000000000000000000000000000000000000..002d8591e19041f22d9c105b85caa51538540f4a --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tensor_op_policy.h @@ -0,0 +1,189 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. + These quantities assume a 'column-major' arrangement of TensorOp instructions, of which + a row-oriented slice is visible per iteration. 
+*/ + +#pragma once + +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy details related to the epilogue +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm:GemmShape) + typename Layout ///< target shared memory layout +> +struct TensorOpPolicy; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape ///< matrix multiply operation shape (concept: gemm::GemmShape) +> +struct TensorOpPolicy { + + /// Number of operations + using OperatorCount = MatrixShape< + (WarpShape::kM + OperatorShape::kM - 1) / OperatorShape::kM, + (WarpShape::kN + OperatorShape::kN - 1) / OperatorShape::kN + >; + + // + // Hard-coded constants regarding Tensor Operations + // + + static int const kElementsPerAccess = 2; + static int const kRowsPerIteration = 8; + static bool const kDivisible = + !(WarpShape::kM % OperatorShape::kM) && !(WarpShape::kN % OperatorShape::kN); + + // + // Derived quantities + // + + // Number of 'externally visible' iterations per actual instruction + static int const kIterationsPerInstruction = OperatorShape::kM / kRowsPerIteration; + + // Number of externally visible iterations + static int const kIterations = OperatorCount::kRow * kIterationsPerInstruction; + + using TileIterations = MatrixShape; + + static int const kAccumulatorRowStride = kElementsPerAccess; + static int const kAccumulatorColumnStride = kElementsPerAccess * OperatorCount::kRow * kIterationsPerInstruction; + +}; + 
+//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape ///< matrix multiply operation shape (concept: gemm::GemmShape) +> +struct TensorOpPolicy { + + /// Number of operations + using OperatorCount = MatrixShape< + (WarpShape::kM + OperatorShape::kM - 1) / OperatorShape::kM, + (WarpShape::kN + OperatorShape::kN - 1) / OperatorShape::kN + >; + + // + // Hard-coded constants regarding Tensor Operations + // + + static int const kElementsPerAccess = 1; + static int const kColumnsPerIteration = 8; + static bool const kDivisible = + !(WarpShape::kM % OperatorShape::kM) && !(WarpShape::kN % OperatorShape::kN); + + // + // Derived quantities + // + + // Number of 'externally visible' iterations per actual instruction + static int const kIterationsPerInstruction = OperatorShape::kN / kColumnsPerIteration; + + // Number of externally visible iterations + static int const kIterations = OperatorCount::kColumn * kIterationsPerInstruction; + + using TileIterations = MatrixShape; + + // Hard code for 16x8 + static int const kAccumulatorRowStride = 2; + static int const kAccumulatorColumnStride = 4 * OperatorCount::kRow; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major-interleaved +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation (concept: arch::Mma) + int InterleavedK ///< number of interleaved k + > +struct TensorOpPolicy > { + /// Number of operations + using OperatorCount = MatrixShape; + + // + // Hard-coded constants regarding Tensor Operations + // + + static int const kElementsPerAccess = 2; + static int const kRowsPerIteration = 8; + + // + // Derived quantities + // + + // Number of 'externally visible' iterations 
per actual instruction + static int const kIterationsPerInstruction = + OperatorShape::kM / kRowsPerIteration; + + // Number of externally visible iterations + static int const kIterations = WarpShape::kN / InterleavedK * + OperatorCount::kRow * + kIterationsPerInstruction; + + static int const kElementsPerIteration = InterleavedK / OperatorShape::kN * kElementsPerAccess; + + static int const kAccessPerIteration = kElementsPerIteration / kElementsPerAccess; + + // Number of externally visible iterations + //static int const kTileIterations = OperatorCount::kRow * kIterationsPerInstruction; + using TileIterations = MatrixShape<1, WarpShape::kN / InterleavedK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h new file mode 100644 index 0000000000000000000000000000000000000000..be7af1355fc634174dac2d15740ad94e15f60fe6 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_simt.h @@ -0,0 +1,785 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" + +#include "cutlass/epilogue/warp/simt_policy.h" + +#define CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES 1 + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename Operator, ///< matrix multiply operation (concept: arch::Mma) + typename Element, ///< data type of element to be written + typename Layout, ///< target shared memory layout + typename MmaSimtPolicy ///< policy defining lane arrangement (concept: MmaSimtPolicy) +> +class TileIteratorSimt; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename Operator_, ///< matrix multiply operation (concept: arch::Mma) + typename Element_, ///< data type of element to be written + typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) +> +class TileIteratorSimt { +public: + + using WarpShape = WarpShape_; + using Operator = Operator_; + using Element = Element_; + using Layout = layout::RowMajor; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = SimtPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + 
WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + typename Operator::ElementC, + Policy::kElementsPerIteration>; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = Array< + typename Operator::ElementC, + Policy::kAccumulatorElementCount>; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + /// Padding quantity + using Padding = MatrixShape< + 0, + 4 * Policy::kElementsPerAccess +#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES + + 1 +#endif + >; + +private: + +#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES + /// Storage type for accessing memory + using AccessType = AlignedArray< + Element, + 1 + >; + +#else + /// Storage type for accessing memory + using AccessType = AlignedArray< + Element, + Policy::kElementsPerAccess + >; +#endif + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorSimt(): pointer_(nullptr) { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorSimt( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / AccessType::kElements) { + + auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); + MatrixCoord lane_offset = lane_layout.inverse(lane_id); + + pointer_ += layout_({ + lane_offset.row(), + lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) + }); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorSimt & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / AccessType::kElements; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorSimt & add_tile_offset(TensorCoord const &tile_offset) 
{ + + pointer_ += layout_({ + tile_offset.row() * Shape::kRow, + (tile_offset.column() * Shape::kColumn / int(AccessType::kElements)) + }); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorSimt & operator+=(TensorCoord const &tile_offset) { + + add_tile_offset(tile_offset); + + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { +#if CUTLASS_SIMT_EPILOGUE_USE_SCALAR_STORES + // de-vectorized stores + using ScalarAccessType = AlignedArray; + ScalarAccessType const *scalarFragPtr = reinterpret_cast(&frag); + ScalarAccessType *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::kElementsPerAccess; s++) { + scalarPointer[n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s] = scalarFragPtr[n * Policy::kElementsPerAccess + s]; + } + } +#else + // original vector stores + AccessType const *frag_ptr = reinterpret_cast(&frag); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n]; + } +#endif + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)]; + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + 
load_with_pointer_offset(frag, 0); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template +class TileIteratorSimtDirectConv { + public: + + using WarpShape = WarpShape_; + using Operator = Operator_; + using Element = Element_; + using Layout = layout::RowMajor; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = SimtPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array; + + /// This is the complete warp-level accumulator tile. 
+ using AccumulatorTile = Array; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + /// Padding quantity + using Padding = MatrixShape<0, + 0 + >; + +private: + /// Storage type for accessing memory + using AccessType = AlignedArray< + Element, + Policy::kElementsPerAccess + >; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + + /// Base smem offset; + Index base_smem_address_; + + public: + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorSimtDirectConv() : pointer_(nullptr) {} + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorSimtDirectConv( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / AccessType::kElements) { + + auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); + MatrixCoord lane_offset = lane_layout.inverse(lane_id); + + pointer_ += layout_({ + lane_offset.row(), + lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) + }); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorSimtDirectConv & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / AccessType::kElements; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorSimtDirectConv & add_tile_offset(TensorCoord const &tile_offset) { + + pointer_ += layout_({ + tile_offset.row() * Shape::kRow, + (tile_offset.column() * Shape::kColumn / int(AccessType::kElements)) + }); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorSimtDirectConv & operator+=(TensorCoord const &tile_offset) { + + add_tile_offset(tile_offset); + + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void 
store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + // original vector stores + AccessType const *frag_ptr = reinterpret_cast(&frag); + AccessType * load_pointer_ = reinterpret_cast(reinterpret_cast(pointer_) + base_smem_address_); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + load_pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n]; + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)]; + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address){ + base_smem_address_ = address; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Template for reading and writing tiles of accumulators to shared memory +template +class TileIteratorSimtDirect2dConv { + public: + using WarpShape = WarpShape_; + using ThreadOutputShape = ThreadOutputShape_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + using Operator = Operator_; + using Element = Element_; + using Layout = layout::RowMajor; + using MmaSimtPolicy = MmaSimtPolicy_; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + // Thread-level shape of a fragment 
+ using ThreadShape = MatrixShape; + + static_assert(!(ThreadShape::kColumn % MmaSimtPolicy::LaneMmaShape::kN), + "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); + + using ThreadTileCount = MatrixShape; + + using Iterations = + MatrixShape; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = typename Operator::FragmentC; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = AccumulatorTile; + + /// Padding quantity + using Padding = MatrixShape<0, 0>; + + private: + // Storage type for accessing memory + using AccessType = AlignedArray; + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + + /// Base smem offset; + Index base_smem_address_; + + public: + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorSimtDirect2dConv() : pointer_(nullptr) {} + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorSimtDirect2dConv(TensorRef const &ref, unsigned thread_id, unsigned lane_id) + : pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / AccessType::kElements) { + + auto lane_layout = MmaSimtPolicy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id); + + // Get base HW offset of current threads + const int threadgroup = thread_id / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); + const int base_p = (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; + const int base_q = (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; + + const int row_offset = base_p * ThreadBlockOutputShape::kW + base_q; + + pointer_ += layout_( + {row_offset, + lane_offset.column() * MmaSimtPolicy::LaneMmaShape::kN / int(AccessType::kElements)}); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorSimtDirect2dConv &add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / 
AccessType::kElements; + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + AccessType *storer_pointer_ = + reinterpret_cast(reinterpret_cast(pointer_) + base_smem_address_); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < ThreadOutputShape::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < ThreadOutputShape::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < Iterations::kColumn; ++col) { + int offset = (w + h * ThreadBlockOutputShape::kW) * + (ThreadBlockOutputShape::kC / AccessType::kElements) + + col; + storer_pointer_[offset + pointer_offset / int(AccessType::kElements)] = + frag_ptr[w + h * ThreadOutputShape::kW + col]; + } + } + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { base_smem_address_ = address; } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename Operator_, ///< matrix multiply operation (concept: arch::Mma) + typename Element_, ///< data type of element to be written + typename Layout_, ///< target shared memory layout + typename MmaSimtPolicy_ ///< policy defining lane arrangement (concept: MmaSimtPolicy) +> +class TileIteratorSimtCanonical { +public: + + using WarpShape = WarpShape_; + using Operator = Operator_; + using Element = Element_; + using Layout = Layout_; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + 
using Policy = SimtPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + typename Operator::ElementC, + Policy::kElementsPerIteration>; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = Array< + typename Operator::ElementC, + Policy::kAccumulatorElementCount>; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + /// Padding quantity + using Padding = MatrixShape< + 0, + 4 * Policy::kElementsPerAccess + 1 + >; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray< + Element, + 1 + >; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + + /// Guard to indicate whether the shape is divisible + bool divisible_; + + /// Extent of the output tensor + MatrixCoord extent_; + + /// Thread offset + MatrixCoord thread_offset_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical(): pointer_(nullptr) { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / AccessType::kElements), + divisible_(true), + extent_(WarpShape::kM, WarpShape::kN) { + + auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); + MatrixCoord lane_offset = lane_layout.inverse(lane_id); + + thread_offset_ = { + lane_offset.row() * Shape::kRow, + lane_offset.column() * Policy::kElementsPerAccess + }; + + pointer_ += layout_({ + lane_offset.row() * Shape::kRow, + lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) + }); + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical( + 
TensorRef const &ref, + TensorCoord const &extent, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / AccessType::kElements), + divisible_(false), + extent_(extent) { + + auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); + MatrixCoord lane_offset = lane_layout.inverse(lane_id); + + thread_offset_ = { + lane_offset.row() * Shape::kRow, + lane_offset.column() * Policy::kElementsPerAccess + }; + + pointer_ += layout_({ + lane_offset.row() * Shape::kRow, + lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) + }); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / AccessType::kElements; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical & add_tile_offset(TensorCoord const &tile_offset) { + + MatrixCoord coord_offset( + tile_offset.row(), + tile_offset.column() * Shape::kColumn + ); + + thread_offset_ += coord_offset; + + pointer_ += layout_({ + coord_offset.row(), + coord_offset.column() + }); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical & operator+=(TensorCoord const &tile_offset) { + + add_tile_offset(tile_offset); + + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + // de-vectorized stores + using ScalarAccessType = AlignedArray; + ScalarAccessType const *scalarFragPtr = reinterpret_cast(&frag); + ScalarAccessType *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::kElementsPerAccess; s++) { + + int ptr_idx = n * 
Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s; + int frag_idx = n * Policy::kElementsPerAccess + s; + + int col = thread_offset_.column() + ptr_idx; + + if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { + scalarPointer[ptr_idx] = scalarFragPtr[frag_idx]; + } + } + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + // de-vectorized loads + using ScalarAccessType = AlignedArray; + ScalarAccessType *scalarFragPtr = reinterpret_cast(&frag); + ScalarAccessType const *scalarPointer = reinterpret_cast(pointer_) + pointer_offset; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::kElementsPerAccess; s++) { + + int ptr_idx = n * Policy::MmaSimtPolicy::WarpShape::kColumn * Policy::kElementsPerAccess + s; + int frag_idx = n * Policy::kElementsPerAccess + s; + + int col = thread_offset_.column() + ptr_idx; + + if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { + scalarFragPtr[frag_idx] = scalarPointer[ptr_idx]; + } + } + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_HOST_DEVICE + TileIteratorSimtCanonical & operator++() { + return add_tile_offset({1, 0}); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7cfa072c4f8dbfb192c10a96ef776e235a7c10cf --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h @@ -0,0 +1,671 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" + +#include "cutlass/epilogue/warp/tensor_op_policy.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename Element, ///< data type of element to be written + typename Layout ///< target shared memory layout +> +class TileIteratorTensorOp; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename Element_ ///< data type of element to be written +> +class TileIteratorTensorOp { 
+public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using TensorLayout = Layout; + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. + //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + /// Number of times this iterator can be incremented + using TileIterations = typename Policy::TileIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + }; + + /// Padding quantity + using Padding = MatrixShape< + 0, + Detail::kLanesInQuad * Policy::kElementsPerAccess>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + + /// Thread offset + MatrixCoord thread_offset_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorTensorOp(): pointer_(nullptr) { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOp( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / Policy::kElementsPerAccess) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % 
Detail::kLanesInQuad); + + thread_offset_ = { + quad_id, lane_in_quad * Policy::kElementsPerAccess + }; + + pointer_ += layout_({thread_offset_.row(), thread_offset_.column() / Policy::kElementsPerAccess}); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOp & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / Policy::kElementsPerAccess; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOp & add_tile_offset(TensorCoord const &tile_offset) { + + MatrixCoord coord_offset( + tile_offset.row() * Shape::kRow, + tile_offset.column() * Shape::kColumn + ); + + thread_offset_ += coord_offset; + + pointer_ += layout_({ + coord_offset.row(), + coord_offset.column() / Policy::kElementsPerAccess + }); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOp & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + pointer_[n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess] = frag_ptr[n]; + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + frag_ptr[n] = pointer_[n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess]; + } + } + + /// Load + CUTLASS_HOST_DEVICE + void 
load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_HOST_DEVICE + TileIteratorTensorOp & operator++() { + return add_tile_offset({1, 0}); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename Element_, ///< data type of element to be written + int InterleavedK ///< number of interleaved k +> +class TileIteratorTensorOp > { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = Element_; + using Layout = layout::ColumnMajorInterleaved; + using TensorLayout = Layout; ///< shared memory tensor ref layout + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< +// Policy::kRowsPerIteration, + WarpShape::kM, + InterleavedK + >; + + /// This is the fragment size produced by one tile + using Fragment = Array< + Element, + Policy::OperatorCount::kRow * Policy::kIterationsPerInstruction + * Policy::kElementsPerIteration>; + + /// This is the fragment size produced by one iteration +// using Fragment = Array< +// Element, Policy::kElementsPerIteration >; + + /// This is the complete warp-level accumulator tile. 
+ //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + using TileIterations = typename Policy::TileIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + }; + + /// Padding quantity + using Padding = MatrixShape< + 0, + Detail::kLanesInQuad * Policy::kElementsPerIteration>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + TensorLayout layout_; + + /// Thread offset + MatrixCoord thread_offset_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorTensorOp(): pointer_(nullptr) { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOp( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0]) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + thread_offset_ = { + quad_id, lane_in_quad * Policy::kElementsPerIteration + }; + + pointer_ += (layout_({thread_offset_.row(), thread_offset_.column()}) / Policy::kElementsPerAccess); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOp & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / Policy::kElementsPerAccess; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOp & add_tile_offset(TensorCoord const &tile_offset) { + + MatrixCoord coord_offset( + tile_offset.row() * Shape::kRow, + tile_offset.column() * Shape::kColumn + ); + + thread_offset_ += coord_offset; + + pointer_ += (layout_({ + coord_offset.row(), + coord_offset.column() + }) / Policy::kElementsPerAccess); + + return *this; + } + + ///< advances in units of whole tiles along the logical 
coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOp & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kRow * Policy::kIterationsPerInstruction; n++ ) { + + AccessType *ptr = pointer_ + layout_({n * Policy::kRowsPerIteration, 0}) / Policy::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int a = 0; a < Policy::kAccessPerIteration; ++a) { + ptr[a + pointer_offset / Policy::kElementsPerAccess] = frag_ptr[n * Policy::kAccessPerIteration + a]; + +// printf("store thread %d, address %p, bank %ld\n", threadIdx.x, pointer_+a+n*Detail::kLanesInQuad, +// ((long long)(pointer_+a+n*Detail::kLanesInQuad)>>2)&0x1f); + } + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kRow * Policy::kIterationsPerInstruction; n++ ) { + + AccessType *ptr = pointer_ + layout_({n * Policy::kRowsPerIteration, 0}) / Policy::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int a = 0; a < Policy::kAccessPerIteration; ++a) { + frag_ptr[n * Policy::kAccessPerIteration + a] = ptr[a + pointer_offset / Policy::kElementsPerAccess]; + } + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_HOST_DEVICE + TileIteratorTensorOp & operator++() { + return add_tile_offset({0, 1}); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename Element_, ///< data type of element to be written + typename Layout_ +> +class TileIteratorTensorOpCanonical { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = Element_; + using Layout = Layout_; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + static int const kAccessSize = 1; + static int const kAccessCount = Policy::kElementsPerAccess / kAccessSize; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. 
+ //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + }; + + /// Padding quantity + using Padding = MatrixShape< + 0, + Detail::kLanesInQuad * Policy::kElementsPerAccess>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + + /// Guard to indicate whether the shape is divisible + bool divisible_; + + /// Extent of the output tensor + MatrixCoord extent_; + + /// Thread offset + MatrixCoord thread_offset_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical(): pointer_(nullptr) { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0]), + divisible_(true), + extent_(WarpShape::kM, WarpShape::kN) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + thread_offset_ = { + quad_id, lane_in_quad * Policy::kElementsPerAccess + }; + + pointer_ += layout_({thread_offset_.row(), thread_offset_.column()}); + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical( + TensorRef const &ref, + TensorCoord const &extent, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0]), + divisible_(false), + extent_(extent) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + thread_offset_ = { + quad_id, lane_in_quad * Policy::kElementsPerAccess + }; + + pointer_ += layout_({thread_offset_.row(), thread_offset_.column()}); + } + + /// Adds a 
pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical & add_tile_offset(TensorCoord const &tile_offset) { + + MatrixCoord coord_offset( + tile_offset.row() * Shape::kRow, + tile_offset.column() * Shape::kColumn + ); + + thread_offset_ += coord_offset; + + pointer_ += layout_({ + coord_offset.row(), + coord_offset.column() + }); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int a = 0; a < kAccessCount; ++a) { + + int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a; + int frag_idx = n * kAccessCount + a; + + int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a; + + if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { + pointer_[ptr_idx] = frag_ptr[frag_idx]; + } + } + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int a = 0; a < kAccessCount; ++a) { + + int ptr_idx = n * 
Detail::kLanesInQuad * kAccessCount + pointer_offset + a; + int frag_idx = n * kAccessCount + a; + + int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a; + + if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { + frag_ptr[frag_idx] = pointer_[ptr_idx]; + } + } + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical & operator++() { + return add_tile_offset({1, 0}); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h new file mode 100644 index 0000000000000000000000000000000000000000..134e668606dc79589f49e38b16fd06d14e97e27d --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h @@ -0,0 +1,1089 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/epilogue/warp/tensor_op_policy.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// This is an optimization available on CUDA 11.2 and beyond that eliminates branches in the epilogue. 
+#define CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED ((__CUDACC_VER_MAJOR__ * 10 + __CUDACC_VER_MINOR__) >= 112) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory. This is optimized +/// for mixed-precision epilogues in which the accumulators are 32b in width, but the output +/// data type is smaller. +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename Element_, ///< data type of accumulator element + int ElementSizeBits, ///< Size of accumulator element in bits + int OutputSizeBits, ///< Size of output element in bits + int OutputElementCount, ///< number of elements in output vector + int ContiguousLanes, ///< Number of consecutive lanes writing to contiguous memory + bool EightBitsOutputOrLess = (OutputSizeBits <= 8) +> +class TileIteratorTensorOpMixed { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kOutputElementCount = OutputElementCount; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. 
+ using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. + //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + + /// Number of pointers needed to write accumulators + static int const kPointerCount = + (OutputElementCount * sizeof_bits::value) / (const_min(128, OutputElementCount * sizeof_bits::value)); + + // Currently support max 4 ptr + static constexpr int kMaxPointerCount{4}; + + static_assert(kPointerCount <= kMaxPointerCount, "Can only accommodate four pointers at present."); + static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); + }; + + /// Padding quantity + using Padding = MatrixShape< + 0, + Detail::kLanesInQuad * Policy::kElementsPerAccess>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointers_[Detail::kPointerCount] = {nullptr}; + + /// Stride in units of AccessType + int stride_{0}; + + /// Logical column in which warp tile is aligned + int warp_column_{0}; + +public: + + /// Default constructor + TileIteratorTensorOpMixed() = default; + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed( + TensorRef const &ref, + unsigned lane_id + ): + stride_(ref.stride()[0] / Policy::kElementsPerAccess), + warp_column_(0) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; + int column_idx = (lane_in_quad % 2) + (((lane_in_quad / 2) + i) % 
Detail::kPointerCount) * 2; + + ptr += column_idx; + + pointers_[i % Detail::kPointerCount] = ptr; + } + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] += pointer_offset / Policy::kElementsPerAccess; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] += tile_offset.row() * Shape::kRow * stride_ + + tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess; + } + + warp_column_ += tile_offset.column() * Shape::kColumn; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { + return add_tile_offset(tile_offset); + } + + /// Store + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + AccessType *ptr = pointers_[0]; + +#if CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED + + // When the optimization is enabled, small tiles require separate logic. 
+ bool kN32_optimization = (WarpShape::kN * Detail::kLanesInQuad * Policy::kElementsPerAccess * sizeof_bits::value) % 1024 == 0; + if (kN32_optimization) { + + int ptr_idx = ((warp_column_ * sizeof_bits::value) / 1024) % Detail::kPointerCount; + + if (ptr_idx == 0) { + ptr = pointers_[0]; + } else if (ptr_idx == 1) { + if constexpr (AccessType::kElements >= 2) { + ptr = pointers_[1]; + } + } else if (ptr_idx == 2) { + if constexpr (AccessType::kElements >= 3) { + ptr = pointers_[2]; + } + } else if (ptr_idx == 3) { + if constexpr (AccessType::kElements >= 4) { + ptr = pointers_[3]; + } + } + } + +#endif + + CUTLASS_PRAGMA_UNROLL + for (int64_t n = 0; n < Policy::OperatorCount::kColumn; ++n) { + +#if CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED + + // + // When the optimization is enabled, this expression suffices to obtain the SMEM pointer. + // + if (WarpShape::kN == 64) { + ptr = pointers_[n / 4]; + } + else if (!kN32_optimization) +#endif + { + // This is the reference implementation + int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess; + int ptr_idx = ((column_idx * sizeof_bits::value) / 1024) % Detail::kPointerCount; + + if (ptr_idx == 0) { + ptr = pointers_[0 % Detail::kPointerCount]; + } + else if (ptr_idx == 1) { + ptr = pointers_[1 % Detail::kPointerCount]; + } + else if (ptr_idx == 2) { + ptr = pointers_[2 % Detail::kPointerCount]; + } + else if (ptr_idx == 3) { + ptr = pointers_[3 % Detail::kPointerCount]; + } + } + + int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess; + ptr[offset] = frag_ptr[n]; + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int64_t n = 0; n < 
Policy::OperatorCount::kColumn; ++n) { + + int column_idx = warp_column_ + n * Detail::kLanesInQuad * Policy::kElementsPerAccess; + int ptr_idx = ((column_idx * sizeof_bits::value) / 1024) % Detail::kPointerCount; + + AccessType const *smem_ptr = pointers_[ptr_idx]; + frag_ptr[n] = smem_ptr[n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess]; + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for int32_t x 16 => int8_t/int4b_t x 16 +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape), + int OutputSizeBits ///< Size of output element in bits +> +class TileIteratorTensorOpMixed { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = int32_t; + using Layout = layout::RowMajor; + static int const kOutputElementCount = 16; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. 
+ //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + + /// Number of pointers needed to write accumulators + static int const kPointerCount = 2; + + /// Offsets added + static int const kOffsetCount = 4; + + static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); + }; + + /// Padding quantity + using Padding = MatrixShape<0, Detail::kLanesInQuad * 2>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointers_[Detail::kPointerCount] = {nullptr}; + + /// Stride in units of AccessType + int stride_{0}; + + /// Uniform offset in bytes added to warp tile iterator + int uniform_offset_[Detail::kOffsetCount] = {0}; + +public: + + /// Default constructor + TileIteratorTensorOpMixed() = default; + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed( + TensorRef const &ref, + unsigned lane_id + ): + stride_(ref.stride()[0] / AccessType::kElements) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; + int column_idx = lane_in_quad ^ (i * 2); + + ptr += column_idx; + + if (i == 0) { + pointers_[0] = ptr; + } + else if (i == 1) { + pointers_[1] = ptr; + } + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kOffsetCount; ++i) { + uniform_offset_[i] = (i ^ 0) * 4 * sizeof(AccessType); + } + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < 
Detail::kPointerCount; ++i) { + pointers_[i] += pointer_offset / AccessType::kElements; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { + + int ptr_offset = tile_offset.row() * Shape::kRow * stride_ + + tile_offset.column() * Shape::kColumn / AccessType::kElements; + + pointers_[0] += ptr_offset; + pointers_[1] += ptr_offset; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kOffsetCount; ++i) { + uniform_offset_[i] = (i ^ tile_offset.column()) * 4 * sizeof(AccessType); + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { + return add_tile_offset(tile_offset); + } + + /// Store + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int ptr_idx = (n / 4); + int offset_idx = (n % 4); + + AccessType *ptr; + if (ptr_idx == 0) { + ptr = pointers_[0]; + } + else if (ptr_idx == 1) { + ptr = pointers_[1]; + } + + int offset = (n / 4) * 16 + pointer_offset / AccessType::kElements; + +#if 0 + // + // Using inline PTX to avoid generic memory + // + AccessType *smem_ptr = pointers_[ptr_idx]; + smem_ptr[offset] = frag_ptr[n]; +#else + uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); + uint32_t const *data = reinterpret_cast(frag_ptr + n); + uint32_t offset_in_bytes = offset * sizeof(AccessType) + uniform_offset_[offset_idx]; + + asm volatile( + "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" + : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) + ); +#endif + } + } + + /// Store + 
CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for int32_t x 8 => int8_t/int4b_t x 8 +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + int OutputSizeBits ///< Size of output element in bits +> +class TileIteratorTensorOpMixed { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = int32_t; + using Layout = layout::RowMajor; + static int const kOutputElementCount = 8; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. 
+ //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + + /// Number of pointers needed to write accumulators + static int const kPointerCount = 2; + + static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); + }; + + /// Padding quantity + using Padding = MatrixShape<0, Detail::kLanesInQuad * 2>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointers_[Detail::kPointerCount] = {nullptr}; + + /// Stride in units of AccessType + int stride_{0}; + +public: + + /// Default constructor + TileIteratorTensorOpMixed() = default; + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed( + TensorRef const &ref, + unsigned lane_id + ): + stride_(ref.stride()[0] / AccessType::kElements) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; + int column_idx = lane_in_quad ^ (i * 2); + + ptr += column_idx; + + if (i == 0) { + pointers_[0] = ptr; + } + else if (i == 1) { + pointers_[1] = ptr; + } + } + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] += pointer_offset / AccessType::kElements; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { + + int 
ptr_offset = tile_offset.row() * Shape::kRow * stride_ + + tile_offset.column() * Shape::kColumn / AccessType::kElements; + + pointers_[0] += ptr_offset; + pointers_[1] += ptr_offset; + + if (tile_offset.column() % 2) { + auto tmp = pointers_[0]; + pointers_[0] = pointers_[1]; + pointers_[1] = tmp; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { + return add_tile_offset(tile_offset); + } + + /// Store + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int ptr_idx = (n / 4); + + AccessType *ptr; + if (ptr_idx == 0) { + ptr = pointers_[0]; + } + else if (ptr_idx == 1) { + ptr = pointers_[1]; + } + + int offset = (n / 4) * 16 + pointer_offset / AccessType::kElements + (n % 4) * 4; + +#if 0 + // + // Using inline PTX to avoid generic memory + // + AccessType *smem_ptr = pointers_[ptr_idx]; + smem_ptr[offset] = frag_ptr[n]; +#else + uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); + uint32_t const *data = reinterpret_cast(frag_ptr + n); + uint32_t offset_in_bytes = offset * sizeof(AccessType); + + asm volatile( + "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" + : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) + ); +#endif + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for float x 16 => float_e4m3_t/float_e5m2_t x 16 +template < + 
typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_ ///< matrix multiply operation shape (concept: gemm::GemmShape), +> +class TileIteratorTensorOpMixed { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = float; + using Layout = layout::RowMajor; + static int const kOutputElementCount = 16; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. 
+ //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + + /// Number of pointers needed to write accumulators + static int const kPointerCount = 2; + + /// Offsets added + static int const kOffsetCount = 4; + + static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); + }; + + /// Padding quantity + using Padding = MatrixShape<0, Detail::kLanesInQuad * 2>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointers_[Detail::kPointerCount] = {nullptr}; + + /// Stride in units of AccessType + int stride_{0}; + + /// Uniform offset in bytes added to warp tile iterator + int uniform_offset_[Detail::kOffsetCount] = {0}; + +public: + + /// Default constructor + TileIteratorTensorOpMixed() = default; + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed( + TensorRef const &ref, + unsigned lane_id + ): + stride_(ref.stride()[0] / AccessType::kElements) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; + int column_idx = lane_in_quad ^ (i * 2); + + ptr += column_idx; + + if (i == 0) { + pointers_[0] = ptr; + } + else if (i == 1) { + pointers_[1] = ptr; + } + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kOffsetCount; ++i) { + uniform_offset_[i] = (i ^ 0) * 4 * sizeof(AccessType); + } + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < 
Detail::kPointerCount; ++i) { + pointers_[i] += pointer_offset / AccessType::kElements; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { + + int ptr_offset = tile_offset.row() * Shape::kRow * stride_ + + tile_offset.column() * Shape::kColumn / AccessType::kElements; + + pointers_[0] += ptr_offset; + pointers_[1] += ptr_offset; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kOffsetCount; ++i) { + uniform_offset_[i] = (i ^ tile_offset.column()) * 4 * sizeof(AccessType); + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { + return add_tile_offset(tile_offset); + } + + /// Store + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int ptr_idx = (n / 4); + int offset_idx = (n % 4); + + AccessType *ptr; + if (ptr_idx == 0) { + ptr = pointers_[0]; + } + else if (ptr_idx == 1) { + ptr = pointers_[1]; + } + + int offset = (n / 4) * 16 + pointer_offset / AccessType::kElements; + +#if 0 + // + // Using inline PTX to avoid generic memory + // + AccessType *smem_ptr = pointers_[ptr_idx]; + smem_ptr[offset] = frag_ptr[n]; +#else + uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); + uint32_t const *data = reinterpret_cast(frag_ptr + n); + uint32_t offset_in_bytes = offset * sizeof(AccessType) + uniform_offset_[offset_idx]; + + asm volatile( + "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" + : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) + ); +#endif + } + } + + /// Store + 
CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for float x 8 => float_e4m3_t/float_e5m2_t x 8 +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_ ///< matrix multiply operation shape (concept: gemm::GemmShape) +> +class TileIteratorTensorOpMixed { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = float; + using Layout = layout::RowMajor; + static int const kOutputElementCount = 8; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. 
+ //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + + /// Number of pointers needed to write accumulators + static int const kPointerCount = 2; + + static_assert(sizeof(Element) == 4, "This can only be used with 32b accumulator data types (f32, s32)."); + }; + + /// Padding quantity + using Padding = MatrixShape<0, Detail::kLanesInQuad * 2>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointers_[Detail::kPointerCount] = {nullptr}; + + /// Stride in units of AccessType + int stride_{0}; + +public: + + /// Default constructor + TileIteratorTensorOpMixed() = default; + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed( + TensorRef const &ref, + unsigned lane_id + ): + stride_(ref.stride()[0] / AccessType::kElements) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + AccessType *ptr = reinterpret_cast(ref.data()) + quad_id * stride_; + int column_idx = lane_in_quad ^ (i * 2); + + ptr += column_idx; + + if (i == 0) { + pointers_[0] = ptr; + } + else if (i == 1) { + pointers_[1] = ptr; + } + } + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_pointer_offset(Index pointer_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int64_t i = 0; i < Detail::kPointerCount; ++i) { + pointers_[i] += pointer_offset / AccessType::kElements; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & add_tile_offset(TensorCoord const &tile_offset) { + + int 
ptr_offset = tile_offset.row() * Shape::kRow * stride_ + + tile_offset.column() * Shape::kColumn / AccessType::kElements; + + pointers_[0] += ptr_offset; + pointers_[1] += ptr_offset; + + if (tile_offset.column() % 2) { + auto tmp = pointers_[0]; + pointers_[0] = pointers_[1]; + pointers_[1] = tmp; + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpMixed & operator+=(TensorCoord const &tile_offset) { + return add_tile_offset(tile_offset); + } + + /// Store + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int ptr_idx = (n / 4); + + AccessType *ptr; + if (ptr_idx == 0) { + ptr = pointers_[0]; + } + else if (ptr_idx == 1) { + ptr = pointers_[1]; + } + + int offset = (n / 4) * 16 + pointer_offset / AccessType::kElements + (n % 4) * 4; + +#if 0 + // + // Using inline PTX to avoid generic memory + // + AccessType *smem_ptr = pointers_[ptr_idx]; + smem_ptr[offset] = frag_ptr[n]; +#else + uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr); + uint32_t const *data = reinterpret_cast(frag_ptr + n); + uint32_t offset_in_bytes = offset * sizeof(AccessType); + + asm volatile( + "{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n" + : : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1]) + ); +#endif + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#undef 
CUTLASS_EPILOGUE_WARP_TILE_ITERATOR_TENSOR_OP_MIXED_OPTIMIZATION_ENABLED + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a18a9ac8f9804da6349512c781174e16f87ce5ed --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h @@ -0,0 +1,440 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" + +#include "cutlass/epilogue/warp/tensor_op_policy.h" +#include "cutlass/epilogue/warp/volta_tensor_op_policy.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename InterleavedTileShape, ///< shape of indivisible instruction-level arrangement (concept: GemmShape) + typename ElementC, ///< Accumulator layout + typename Layout ///< target shared memory layout +> +struct TileIteratorVoltaTensorOp; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) +> +struct TileIteratorVoltaTensorOp, half_t, layout::RowMajor> { +public: + + using WarpShape = WarpShape_; + using InterleavedTileShape = 
gemm::GemmShape<32, 32, 4>; + using Element = half_t; + using Layout = layout::RowMajor; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = VoltaTensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// Array type for aligned memory accesses + using AccessType = typename Policy::AccessType; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = typename Policy::Fragment; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = typename Policy::AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + /// Number of elements per access + static int const kElementsPerAccess = Policy::kElementsPerAccess; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + static int const kRowsPerQuad = 4; + static int const kColumnsPerQuad = 8; + static int const kAccessesPerQuad = kColumnsPerQuad / Policy::kElementsPerAccess; + static int const kAccessQuadDelta = 16; + }; + + /// Padding quantity + using Padding = MatrixShape< + 0, + Policy::kElementsPerAccess>; + +private: + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorVoltaTensorOp(): pointer_(nullptr) { } + + /// Constructor from TensorRef + CUTLASS_DEVICE + TileIteratorVoltaTensorOp( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / Policy::kElementsPerAccess) { + + int quad_id = lane_id / Detail::kLanesInQuad; + int lane_in_quad = (lane_id % 
Detail::kLanesInQuad); + + int quad_row_idx = ((quad_id & 4) >> 1) + (quad_id & 1); + int quad_col_idx = ((quad_id & 2) >> 1); + + int row = quad_row_idx * Detail::kRowsPerQuad + lane_in_quad; + int column = quad_col_idx * Detail::kColumnsPerQuad; + + pointer_ += layout_({row, column / kElementsPerAccess}); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorVoltaTensorOp & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / Policy::kElementsPerAccess; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorVoltaTensorOp & add_tile_offset(TensorCoord const &tile_offset) { + + pointer_ += layout_({ + tile_offset.row() * Shape::kRow, + tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess}); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorVoltaTensorOp & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + /// Store + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int tile_idx = 0; tile_idx < Policy::TileIterations::kColumn; ++tile_idx) { + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < Policy::kAccessesPerInterleavedTile; ++access_idx) { + + int access_quad = access_idx / 2; + int access = access_idx % 2; + + int ptr_offset = tile_idx * InterleavedTileShape::kN / Policy::kElementsPerAccess + + access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess + + access + pointer_offset / Policy::kElementsPerAccess; + + int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx; + + AccessType access_vector = frag_ptr[frag_idx]; + + pointer_[ptr_offset] = access_vector; + } + } + } + + /// Store + 
CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int tile_idx = 0; tile_idx < Policy::TileIterations::kColumn; ++tile_idx) { + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < Policy::kAccessesPerInterleavedTile; ++access_idx) { + + int access_quad = access_idx / 2; + int access = access_idx % 2; + + int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta + + access + pointer_offset / Policy::kElementsPerAccess; + + int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx; + + frag_ptr[frag_idx] = pointer_[ptr_offset]; + } + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment const &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) +> +struct TileIteratorVoltaTensorOp, float, layout::RowMajor> { +public: + + using WarpShape = WarpShape_; + using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; + using Element = float; + using Layout = layout::RowMajor; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = VoltaTensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// Array type for aligned memory accesses 
+ using AccessType = typename Policy::AccessType; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = typename Policy::Fragment; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = typename Policy::AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + /// Number of elements per access + static int const kElementsPerAccess = Policy::kElementsPerAccess; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + static int const kRowsPerQuad = 4; + static int const kColumnsPerQuad = 8; + static int const kAccessesPerQuad = kColumnsPerQuad / Policy::kElementsPerAccess; + static int const kAccessQuadDelta = 16; + }; + + /// Padding quantity + using Padding = MatrixShape< + 0, + Policy::kElementsPerAccess>; + +private: + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorVoltaTensorOp(): pointer_(nullptr) { } + + /// Constructor from TensorRef + CUTLASS_DEVICE + TileIteratorVoltaTensorOp( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / Policy::kElementsPerAccess) { + + int quad_id = lane_id / Detail::kLanesInQuad; + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + int const kQuadRowDelta = 4; + int const kQuadColumnDelta = 2 * Policy::MmaIterations::kColumn; + + int quad_row_offset = ((quad_id & 4) / 2 + (quad_id & 1)) * kQuadRowDelta; + int quad_column_offset = (quad_id & 2) / 2 * kQuadColumnDelta; + + int thread_row_offset = (lane_in_quad & 1); + int thread_column_offset = (lane_in_quad & 2) / 2; + + int row = quad_row_offset + thread_row_offset; + int column = quad_column_offset + thread_column_offset; + + pointer_ += layout_({row, column}); + } + + /// Adds a 
pointer offset + CUTLASS_HOST_DEVICE + TileIteratorVoltaTensorOp & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / Policy::kElementsPerAccess; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorVoltaTensorOp & add_tile_offset(TensorCoord const &tile_offset) { + + pointer_ += layout_({ + tile_offset.row() * Shape::kRow, + tile_offset.column() * Shape::kColumn / Policy::kElementsPerAccess}); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorVoltaTensorOp & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + /// Store + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + int const kAccessesPerRow = Policy::TileIterations::kColumn * Policy::MmaIterations::kColumn * 2; + + CUTLASS_PRAGMA_UNROLL + for (int row_idx = 0; row_idx < Policy::kRowsPerMmaTile; ++row_idx) { + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < kAccessesPerRow; ++access_idx) { + + int frag_idx = row_idx * kAccessesPerRow + access_idx; + + int ptr_column_offset = (access_idx & 1) * 2 + + (access_idx & 2) * Policy::MmaIterations::kColumn * 2 + + (access_idx & 4) * Policy::MmaIterations::kColumn * 2; + + int ptr_row_offset = row_idx * 2; + + int ptr_offset = layout_({ptr_row_offset, ptr_column_offset}) + pointer_offset / Policy::kElementsPerAccess; + + pointer_[ptr_offset] = frag_ptr[frag_idx]; + } + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + assert(0); + } + + /// Load + 
CUTLASS_HOST_DEVICE + void load(Fragment const &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8129dce1d80d805054c2c35a83797379522c3121 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h @@ -0,0 +1,224 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/wmma_array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/warp/wmma_tensor_op_policy.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorFragment, ///< wmma fragment to be written (concept: nvcuda::wmma::fragment) + typename Layout ///< target shared memory layout +> +class TileIteratorWmmaTensorOp; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 
Template for reading and writing tiles of accumulators to shared memory +template < + typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorFragment_ ///< wmma fragment to be written (concept: nvcuda::wmma::fragment) +> +class TileIteratorWmmaTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorFragment = OperatorFragment_; + using Layout = layout::RowMajor; + + // + // Derived types + // + using WmmaDataType = typename OperatorFragment::element_type; + using Element = typename cutlass::arch::WmmaToCutlassDataType::Type; ///< Data Type of element stored in nvcuda::wmma::frament + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = WmmaTensorOpPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = WmmaFragmentArray; + + + /// This is the complete warp-level accumulator tile. 
+ //using AccumulatorTile = typename Operator::FragmentC; + + + /// Padding quantity + // (Epilogue shared memory padding for WMMA Gemm kernel is set to run optimaly on Turing) + using Padding = MatrixShape< + 0, + 4 * Policy::kElementsPerAccess + >; + +private: + + /// Storage type for accessing memory + //using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to shared memory + TensorRef ref_; + + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorWmmaTensorOp(): ref_(nullptr) { + + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorWmmaTensorOp( + TensorRef const &ref, + unsigned lane_id + ): ref_(ref) { + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorWmmaTensorOp & add_pointer_offset(Index pointer_offset) { + ref_.add_pointer_offset(pointer_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorWmmaTensorOp & add_tile_offset(TensorCoord const &tile_offset) { + ref_.add_coord_offset({tile_offset.row() * OperatorShape::kM, tile_offset.column() * WarpShape::kN}); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorWmmaTensorOp & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + for(int n=0; n < Policy::OperatorCount::kColumn; n++) { + + WmmaDataType* ptr = reinterpret_cast (ref_.data() + ref_.offset({0, n * OperatorShape::kN}) + pointer_offset); + + nvcuda::wmma::store_matrix_sync( + ptr, + frag[n], + ref_.stride()[0], + nvcuda::wmma::layout_t::mem_row_major + ); + + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + 
CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + for(int n=0; n < Policy::OperatorCount::kColumn; n++) { + + WmmaDataType* ptr = reinterpret_cast (ref_.data() + ref_.offset({0, n * OperatorShape::kN}) + pointer_offset); + + nvcuda::wmma::load_matrix_sync( + frag[n], + ptr, + ref_.stride()[0], + nvcuda::wmma::layout_t::mem_row_major + ); + + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h new file mode 100644 index 0000000000000000000000000000000000000000..c108fc91cab2349cea54c758a3b19237aa7b692d --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/volta_tensor_op_policy.h @@ -0,0 +1,195 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. + These quantities assume a 'column-major' arrangement of TensorOp instructions, of which + a row-oriented slice is visible per iteration. 
+*/ + +#pragma once + +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Policy details related to the epilogue +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename InterleavedTileShape, ///< shape of indivisible instruction-level arrangement (concept: GemmShape) + typename ElementC, ///< Accumulator layout + typename Layout ///< target shared memory layout +> +struct VoltaTensorOpPolicy; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major +template < + typename WarpShape_ ///< shape of warp-level GEMM (concept: GemmShape) +> +struct VoltaTensorOpPolicy, half_t, layout::RowMajor> { + + using WarpShape = WarpShape_; + using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; + using ElementC = half_t; + using Layout = layout::RowMajor; + + /// Shape of one warp-levelinstruction + using InstructionShape = gemm::GemmShape<16, 16, 4>; + + /// Number of mma operations performed for one 32x32x4 interleaved tile + using MmaIterations = MatrixShape< + InterleavedTileShape::kM / InstructionShape::kM, + InterleavedTileShape::kN / InstructionShape::kN + >; + + /// Number of 32x32x4 interleaved tiles performed to cover the warp-level GEMM shape + using TileIterations = MatrixShape< + WarpShape::kM / InterleavedTileShape::kM, + WarpShape::kN / InterleavedTileShape::kN + >; + + /// Number of accumulator elements owned by each thread per Mma + static int const kElementsPerMma = 8; + static int const kRowsPerIteration = 16; + + // + // Hard-coded constants regarding Tensor Operations + // + + /// Number of accumulator elements 
stored per memory instruction to shared memory + static int const kElementsPerAccess = 4; + + /// Number of accesses performed per interleaved tile + static int const kAccessesPerInterleavedTile = 4; + + /// Total number of iterations needed to cover the entire tile + static int const kIterations = TileIterations::kRow * 2; + + // + // Derived types + // + + /// Array type for aligned memory accesses + using AccessType = AlignedArray; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + ElementC, + kElementsPerAccess * kAccessesPerInterleavedTile * TileIterations::kColumn>; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = Array< + ElementC, + TileIterations::kCount * MmaIterations::kCount * kElementsPerMma>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major +template < + typename WarpShape_ ///< shape of warp-level GEMM (concept: MatrixShape) +> +struct VoltaTensorOpPolicy, float, layout::RowMajor> { + + using WarpShape = WarpShape_; + using InterleavedTileShape = gemm::GemmShape<32, 32, 4>; + using ElementC = float; + using Layout = layout::RowMajor; + + /// Shape of one warp-levelinstruction + using InstructionShape = gemm::GemmShape<16, 16, 4>; + + /// Number of mma operations performed for one 32x32x4 interleaved tile + using MmaIterations = MatrixShape< + InterleavedTileShape::kM / InstructionShape::kM, + InterleavedTileShape::kN / InstructionShape::kN + >; + + /// Number of 32x32x4 interleaved tiles performed to cover the warp-level GEMM shape + using TileIterations = MatrixShape< + WarpShape::kM / InterleavedTileShape::kM, + WarpShape::kN / InterleavedTileShape::kN + >; + + /// Number of accumulator elements owned by each thread per Mma + static int const kElementsPerMma = 8; + static int const kRowsPerIteration = 16; + + // + // Hard-coded constants regarding Tensor Operations 
+ // + + /// Number of accumulator elements stored per memory instruction to shared memory + static int const kElementsPerAccess = 2; + + /// Number of accesses performed per interleaved tile + static int const kAccessesPerInterleavedTile = 8; + + /// Number of rows per interleaved tile + static int const kRowsPerMmaTile = 2; + + /// Total number of iterations needed to cover the entire tile + static int const kIterations = TileIterations::kRow * MmaIterations::kRow; + + // + // Derived types + // + + /// Array type for aligned memory accesses + using AccessType = AlignedArray; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + ElementC, + kElementsPerAccess * kAccessesPerInterleavedTile * TileIterations::kColumn>; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = Array< + ElementC, + TileIterations::kCount * MmaIterations::kCount * kElementsPerMma>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h new file mode 100644 index 0000000000000000000000000000000000000000..01b1e72e52181a2556720340f2483716f24264c2 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/warp/wmma_tensor_op_policy.h @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic structures needed for implementing the warp-scoped phase of the epilogue. + These quantities assume a 'column-major' arrangement of TensorOp instructions, of which + a row-oriented slice is visible per iteration. 
+*/ + +#pragma once + +#include "cutlass/arch/wmma.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy details related to the epilogue +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm:GemmShape) + typename Layout ///< target shared memory layout +> +struct WmmaTensorOpPolicy; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape ///< matrix multiply operation shape (concept: gemm::GemmShape) +> +struct WmmaTensorOpPolicy { + + /// Number of operations + using OperatorCount = MatrixShape< + WarpShape::kM / OperatorShape::kM, + WarpShape::kN / OperatorShape::kN + >; + + // + // Hard-coded constants regarding Tensor Operations + // + static int const kElementsPerAccess = 2; + static int const kRowsPerIteration = OperatorShape::kM; + static int const kWmmaFragmentsPerAccess = 1; + + // + // Derived quantities + // + + // Number of externally visible iterations + static int const kIterations = OperatorCount::kRow; + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// + +#endif + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/exmy_base.h 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/exmy_base.h new file mode 100644 index 0000000000000000000000000000000000000000..be207a4952ead88b1f6717fd1e66728e351f8bf1 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/exmy_base.h @@ -0,0 +1,1222 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! + \file + \brief Generic floating-point type for ExMy format +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_size.h" +#include "cutlass/platform/platform.h" + +// #define CUTLASS_DEBUG_TRACE_LEVEL 2 +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + // Helper functions +namespace detail { + +template +CUTLASS_HOST_DEVICE +Dst copy_bits(Src src) +{ + Dst dst; + static_assert(sizeof(Src) <= sizeof(Dst), "Dst type should be at least the same size as Src type"); + static_assert(cutlass::platform::is_trivially_copyable::value, "Dst type should be trivially copyable"); + static_assert(cutlass::platform::is_trivially_copyable< + /*cutlass::platform::remove_cvref_t< */ Dst /* > */ >::value, "Dst type should be trivially copyable"); + memcpy(&dst, &src, sizeof(src)); + return dst; +} + +enum class NanInfEncoding +{ + // IEEE-754 style NaN. Exponent bits are + // all ones, and at least one bit of mantissa is one + IEEE_754, + // Canonical NaN. There is only one value representing NaN and + // no Inf is defined. + CANONICAL_ONLY, + // No NaN or Inf encoded. 
+ NONE +}; + +enum class FpEncoding +{ + E11M52, // double + E8M23, // float + E5M2, // FP8 + E4M3, // FP8 + UE4M3, // FP8 + UE8M0, // FP8 + E3M2, // FP6 + E2M3, // FP6 + E2M1, // FP4 +}; + +////// + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr int exponent_bias_cxx17() { + if CUTLASS_CONSTEXPR_IF_CXX17 (NumExpBits == 0) { + static_assert(NumMantissaBits <= static_cast(cutlass::platform::numeric_limits::max())); + return -1 * static_cast(NumMantissaBits); + } + else { + return static_cast((1 << (NumExpBits - 1))) - 1; + } + + CUTLASS_GCC_UNREACHABLE; +} +#endif + +namespace impl { +template +constexpr int shift_num_bits_expression_cxx11() { +#if (CUTLASS_CXX17_OR_LATER) + static_assert(NumExpBitsMinusOne <= 31u); +#endif + return NumExpBitsMinusOne > 31u ? 31u : NumExpBitsMinusOne; +} + +template +constexpr int inner_shift_expression_cxx11() { + return static_cast((1u << shift_num_bits_expression_cxx11()) - 1u); +} + +} // namespace impl + +// C++11 equivalent of exponent_bias_cxx17() +template +constexpr int exponent_bias_cxx11() { +#if (CUTLASS_CXX17_OR_LATER) + return exponent_bias_cxx17(); +#else + return (NumExpBits == 0) ? + -1 * static_cast(NumMantissaBits) : impl::inner_shift_expression_cxx11(); +#endif +} + +// C++11 equivalent of maximum_exponent_cxx17() +template +constexpr int maximum_exponent_cxx11() { + return + ((NumExpBits == 0) ? + (0 - exponent_bias_cxx11()) : + ((NaNEncoding == NanInfEncoding::IEEE_754) ? + ((static_cast((1 << NumExpBits)) - 2) - exponent_bias_cxx11()) : + ((NaNEncoding == NanInfEncoding::CANONICAL_ONLY) ? + ((NumMantissaBits > 0) ? 
+ static_cast((1 << NumExpBits)) - 1 - exponent_bias_cxx11() : + static_cast((1 << NumExpBits)) - 2 - exponent_bias_cxx11() + ) : + (static_cast((1 << NumExpBits)) - 1 - exponent_bias_cxx11()) + ) + ) + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr int maximum_exponent_cxx17() { + constexpr int exp_bias = exponent_bias_cxx17(); + if CUTLASS_CONSTEXPR_IF_CXX17 (NumExpBits == 0) { + // If no exponent bits, return fixed hidden bias + return 0 - exp_bias; + } + else { + if CUTLASS_CONSTEXPR_IF_CXX17 (NaNEncoding == NanInfEncoding::IEEE_754) { + // We have IEEE style NaN and infinity + // All values when exp_bits = 1...1s are used. + int max_exp_bits = static_cast((1 << NumExpBits)) - 2; + return max_exp_bits - exp_bias; + } + else { + // There are no cases where we have Inf without IEEE_754_Nan + + // If we have a canonical NaN. Only exp=1..1 and mantissa=1..1 + // value has a special meaning. If we also have at least one mantissa + // bit, then maximum exponent is 1...1 - exponent_bias + if CUTLASS_CONSTEXPR_IF_CXX17 (NaNEncoding == NanInfEncoding::CANONICAL_ONLY) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NumMantissaBits > 0) { + int max_exp_bits = static_cast((1 << NumExpBits)) - 1; + return max_exp_bits - exp_bias; + } + else { // no mantissa bits + int max_exp_bits = static_cast((1 << NumExpBits)) - 2; + return max_exp_bits - exp_bias; + } + } + // No NaNs or infs + int max_exp_bits = static_cast((1 << NumExpBits)) - 1; + return max_exp_bits - exp_bias; + } + } + + CUTLASS_GCC_UNREACHABLE; +} +#endif + +template +constexpr int minimum_exponent_cxx11() { + return + ((NumExpBits == 0) ? + 0 - exponent_bias_cxx11() : + ((NumMantissaBits > 0) ? 
+ 1 - exponent_bias_cxx11() : + 0 - exponent_bias_cxx11()) + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr int minimum_exponent_cxx17() { + constexpr int exp_bias = exponent_bias_cxx17(); + constexpr bool has_denorm = (NumMantissaBits > 0); + if CUTLASS_CONSTEXPR_IF_CXX17 (NumExpBits == 0) { + // If no exponent bits, return fixed hidden bias + // Note that minimum and maximum exponents are the same. + return 0 - exp_bias; + } + + if CUTLASS_CONSTEXPR_IF_CXX17 (has_denorm) { + // Exp = 0...0s is reserved for denorm values. + return 1 - exp_bias; + } + return 0 - exp_bias; +} +#endif + +template +constexpr Storage max_pos_denormal_value_cxx11() { + static_assert(NumExpBits > 0 || NumMantissaBits > 0, "Both NumExpBits and NumMantissaBits can't be zero"); + return + (!(NumMantissaBits > 0) ? Storage(0) : Storage((1ull << NumMantissaBits) - 1)); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage max_pos_denormal_value_cxx17() { + static_assert(NumExpBits > 0 || NumMantissaBits > 0, "Both NumExpBits and NumMantissaBits can't be zero"); + constexpr bool has_denorm = (NumMantissaBits > 0); + if CUTLASS_CONSTEXPR_IF_CXX17 (!has_denorm) { + // If we don't have denormal values, return all 0s + return Storage(0); + } + else { + // Case: (NumExpBits > 0 && NumMantissaBits > 0) or (NumExpBits == 0 && NumMantissaBits > 0) + return Storage((1ull << NumMantissaBits) - 1); + } + + CUTLASS_GCC_UNREACHABLE; +} +#endif + + +template +constexpr Storage min_pos_denormal_value_cxx11() { + return (!(NumMantissaBits > 0) ? 
Storage(0) : Storage(1)); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage min_pos_denormal_value_cxx17() { + constexpr bool has_denorm = (NumMantissaBits > 0); + if CUTLASS_CONSTEXPR_IF_CXX17 (!has_denorm) { + // If we don't have denormal values, return all 0s + return Storage(0); + } + // Case: (NumExpBits > 0 && NumMantissaBits > 0) or (NumExpBits == 0 && NumMantissaBits > 0) + return Storage(1); +} +#endif + +template +constexpr Storage max_pos_normal_value_cxx11() { + return + ((NumExpBits == 0) ? + Storage(0) : + ((NumMantissaBits == 0) ? + 0 : + (((NaNEncoding == NanInfEncoding::IEEE_754 || NaNEncoding == NanInfEncoding::NONE) ? + ((1ull << NumMantissaBits) - 1) : + ((1ull << NumMantissaBits) - 2))) + ) | (static_cast( + maximum_exponent_cxx11() + + exponent_bias_cxx11() + ) << NumMantissaBits) + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage max_pos_normal_value_cxx17() { + if CUTLASS_CONSTEXPR_IF_CXX17 (NumExpBits == 0) { + // if there are no exponent bits, we don't have normal values. + return Storage(0); + } + constexpr int exp_bias = exponent_bias_cxx17(); + constexpr int max_exp = maximum_exponent_cxx17(); + constexpr int exp = max_exp + exp_bias; + + // place the exponent + Storage val = static_cast(exp) << NumMantissaBits; + // If there are no mantissa bits return the exponent + if CUTLASS_CONSTEXPR_IF_CXX17 (NumMantissaBits == 0) { + return val; + } + else { + // If the NaN Inf encoding follows IEEE 754 or there is no (NaN and Inf) then mantissa can be all 1..1s + if CUTLASS_CONSTEXPR_IF_CXX17 (NaNEncoding == NanInfEncoding::IEEE_754 || + NaNEncoding == NanInfEncoding::NONE ) { + Storage mantissa = (1ull << NumMantissaBits) - 1; + val |= mantissa; + } + else { + // If we have a canonical NaN, then the exponent can be the maximum bit value + // but mantissa=1..1s is reserved for NaN. 
+ Storage mantissa = (1ull << NumMantissaBits) - 2; + val |= mantissa; + } + return val; + } + + CUTLASS_GCC_UNREACHABLE; +} +#endif + +template +constexpr Storage min_pos_normal_value_cxx11() { + return + ((NumExpBits == 0) ? + Storage(0) : + (Storage((NumMantissaBits > 0) ? 1 : 0) << NumMantissaBits) + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage min_pos_normal_value_cxx17() { + constexpr bool has_denorm = (NumMantissaBits > 0); + + if CUTLASS_CONSTEXPR_IF_CXX17 (NumExpBits == 0) { + // if there are no exponent bits, we don't have normal values. + return Storage(0); + } + Storage exp = 0; + if CUTLASS_CONSTEXPR_IF_CXX17 (has_denorm) { + exp = 1; + } + return static_cast(exp << NumMantissaBits); +} +#endif + +template +constexpr Storage max_value_cxx11() { + return + ((NumExpBits > 0) ? + max_pos_normal_value_cxx11() : + max_pos_denormal_value_cxx11() + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage max_value_cxx17() { + constexpr bool has_normal = (NumExpBits > 0); + if CUTLASS_CONSTEXPR_IF_CXX17 (has_normal) { + return max_pos_normal_value_cxx17(); + } + else { + return max_pos_denormal_value_cxx17(); + } + + CUTLASS_GCC_UNREACHABLE; +} +#endif + +template +constexpr Storage min_value_cxx11() { + return + (IsSigned ? 
+ Storage(1ull << (NumExpBits + NumMantissaBits)) | max_value_cxx11() : + Storage(0) + ); +} + +#if (CUTLASS_CXX17_OR_LATER) +template +constexpr Storage min_value_cxx17() { + if (IsSigned) { + return Storage(1ull << (NumExpBits + NumMantissaBits)) | max_value_cxx17(); + } + else { // Unsigned number + return Storage(0); + } + + CUTLASS_GCC_UNREACHABLE; +} +#endif + +template < + class StorageType, + uint32_t NumBits, uint32_t NumExpBits, uint32_t NumMantissaBits, + NanInfEncoding Nan = NanInfEncoding::IEEE_754, bool IsSigned = true> +struct FpBitRepresentation { +public: + + using Storage = StorageType; + +#if (CUTLASS_CXX17_OR_LATER) + static_assert(cutlass::platform::is_unsigned_v, "Use an unsigned integer for StorageType"); +#endif + static constexpr bool IS_SIGNED = IsSigned; + // Canonical NaN is always represented as exponent=11...11 and mantissa=11...11, if it exists + static constexpr NanInfEncoding NAN_TYPE = Nan; + // Inf is always represented as exponent=11...11 and mantissa=00...00, if it exists + static constexpr bool HAS_INF = (NAN_TYPE == NanInfEncoding::IEEE_754); + static constexpr bool HAS_NAN = (NAN_TYPE != NanInfEncoding::NONE); + + static constexpr bool HAS_DENORM = (NumMantissaBits > 0); + static constexpr bool HAS_NORMAL = !HAS_DENORM; + + static constexpr uint32_t NUM_BITS = NumBits; + static constexpr uint32_t NUM_EXPONENT_BITS = NumExpBits; + static constexpr uint32_t NUM_MANTISSA_BITS = NumMantissaBits; + static_assert(NUM_BITS >= (NUM_EXPONENT_BITS + NUM_MANTISSA_BITS + uint32_t(IS_SIGNED)), "Number of bits do not match"); + + static constexpr Storage ONE = Storage(1); + static constexpr Storage ZERO = Storage(0); + + // Note: Don't rely on operator precedence. Use parenthesis. 
+ static constexpr Storage EXPONENT_MASK = (Storage(1) << Storage(NUM_EXPONENT_BITS)) - ONE; + static constexpr Storage MANTISSA_MASK = (Storage(1) << Storage(NUM_MANTISSA_BITS)) - ONE; + static constexpr Storage EXPONENT_SHIFT = Storage(NUM_MANTISSA_BITS); + static constexpr Storage SIGN_SHIFT = (IS_SIGNED) ? Storage(NUM_MANTISSA_BITS + NUM_EXPONENT_BITS) : Storage(0); + + // Note: All biased/real exponent calculation are done with signed ints + // Use unsigned to represent data not exponent. + static constexpr int EXP_BIAS = detail::exponent_bias_cxx11(); + static constexpr int MAX_EXP = detail::maximum_exponent_cxx11(); + static constexpr int MIN_EXP = detail::minimum_exponent_cxx11(); + + // Floating-point Limits + static constexpr Storage MAX_POS_NORMAL_VAL = detail::max_pos_normal_value_cxx11(); + static constexpr Storage MAX_POS_DENORMAL_VAL = detail::max_pos_denormal_value_cxx11(); + static constexpr Storage MIN_POS_NORMAL_VAL = detail::min_pos_normal_value_cxx11(); + static constexpr Storage MIN_POS_DENORMAL_VAL = detail::min_pos_denormal_value_cxx11(); + + static constexpr Storage MAX_VALUE = max_value_cxx11(); + static constexpr Storage MIN_VALUE = min_value_cxx11(); + + // + // C++17 Verification + // +#if (CUTLASS_CXX17_OR_LATER) + static_assert(EXP_BIAS == detail::exponent_bias_cxx17(), "Error"); + static_assert(MAX_EXP == detail::maximum_exponent_cxx17(), "Error"); + static_assert(MIN_EXP == detail::minimum_exponent_cxx17(), "Error"); + + static_assert(MAX_POS_NORMAL_VAL == detail::max_pos_normal_value_cxx17(), "Error"); + static_assert(MAX_POS_DENORMAL_VAL == detail::max_pos_denormal_value_cxx17(), "Error"); + static_assert(MIN_POS_NORMAL_VAL == detail::min_pos_normal_value_cxx17(), "Error"); + static_assert(MIN_POS_DENORMAL_VAL == detail::min_pos_denormal_value_cxx17(), "Error"); + static_assert(MAX_VALUE == max_value_cxx17(), "Error"); + static_assert(MIN_VALUE == min_value_cxx17(), "Error"); +#endif + + // If we don't have INF defined, set the 
largest number. Gives us .satfinite behavior. + static constexpr Storage INF_MASK = (HAS_INF) ? + (Storage(EXPONENT_MASK) << Storage(NUM_MANTISSA_BITS)) : MAX_VALUE; + static constexpr Storage NAN_MASK = (Storage(EXPONENT_MASK) << Storage(NUM_MANTISSA_BITS)) | MANTISSA_MASK; + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 bool is_inf(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (!HAS_INF) { + return false; + } + bool exp_all_ones = (exponent_bits(flt) ^ EXPONENT_MASK) == 0; + bool mantissa_all_zeros = mantissa_bits(flt) == 0; + return exp_all_ones && mantissa_all_zeros; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 bool is_canonical_nan(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NAN_TYPE == NanInfEncoding::NONE) { + return false; + } + bool exp_all_ones = (exponent_bits(flt) ^ EXPONENT_MASK) == ZERO; + bool mantissa_all_ones = (mantissa_bits(flt) ^ MANTISSA_MASK) == ZERO; + return exp_all_ones && mantissa_all_ones; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 bool is_nan(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NAN_TYPE == NanInfEncoding::NONE) { + return false; + } + + if CUTLASS_CONSTEXPR_IF_CXX17 (NAN_TYPE == NanInfEncoding::CANONICAL_ONLY) { + return is_canonical_nan(flt); + } + + bool exp_all_ones = (exponent_bits(flt) ^ EXPONENT_MASK) == ZERO; + bool mantissa_has_ones = mantissa_bits(flt) > ZERO; + return exp_all_ones && mantissa_has_ones; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 bool is_denorm(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (!HAS_DENORM) { + return false; + } + else if (exponent_bits(flt) == ZERO) { + // Exponent bits are all 0s + return true; + } + return false; + } + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 T sign_bit(T flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (!IS_SIGNED) { + return T(0); + } + return static_cast(flt >> T(SIGN_SHIFT)); + } + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 T set_sign_bit(T 
flt, T sign) { + if CUTLASS_CONSTEXPR_IF_CXX17 (!IS_SIGNED) { + return flt; + } + return static_cast(flt | (sign << T(SIGN_SHIFT))); + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage exponent_bits(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NUM_EXPONENT_BITS == ZERO) { + return ZERO; + } + return (flt >> (NUM_MANTISSA_BITS)) & EXPONENT_MASK; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 int exponent(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NUM_EXPONENT_BITS == ZERO) { + return -int(EXP_BIAS); + } + + if (HAS_DENORM && (exponent_bits(flt) == ZERO)) { + return 1 - int(EXP_BIAS); + } + + return int(flt >> (NUM_MANTISSA_BITS) & EXPONENT_MASK) - int(EXP_BIAS); + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage mantissa_bits(Storage flt) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NUM_MANTISSA_BITS == ZERO) { + return ZERO; + } + return (flt & MANTISSA_MASK); + } + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage to_bits(FpType flt) { + return copy_bits(flt); + } + + template + CUTLASS_HOST_DEVICE static typename DstFpBits::Storage convert_to( + Storage src_val, + DstFpBits dst_encoding) { + return convert(FpBitRepresentation{}, src_val, dst_encoding); + } + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage convert_from( + typename SrcFpBits::Storage src_val, + SrcFpBits src_encoding) { + return convert(src_encoding, src_val, FpBitRepresentation{}); + } + +private: + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 T make_fp_from_bits(T sign, T exp, T mantissa) { + T fp_bits = T(ZERO); + CUTLASS_UNUSED(sign); + if CUTLASS_CONSTEXPR_IF_CXX17 (IS_SIGNED) { + fp_bits = sign << SIGN_SHIFT; + } + fp_bits |= (exp << T(NUM_MANTISSA_BITS)); + fp_bits |= (mantissa); + return fp_bits; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage nan_with_sign(Storage sign) { + Storage fp_bits = NAN_MASK; + return set_sign_bit(fp_bits, 
sign); + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage inf_with_sign(Storage sign) { + if CUTLASS_CONSTEXPR_IF_CXX17 (HAS_INF) { + Storage fp_bits = INF_MASK; + return set_sign_bit(fp_bits, sign); + } + else { + // If INF is not defined assume satfinite behavior + return (sign == ZERO) ? MAX_VALUE : MIN_VALUE; + } + + CUTLASS_GCC_UNREACHABLE; + } + + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 Storage significand(Storage flt) { + if (is_denorm(flt)) { + return mantissa_bits(flt); + } + else { + return (ONE << Storage(NUM_MANTISSA_BITS)) | mantissa_bits(flt); + } + + CUTLASS_GCC_UNREACHABLE; + } + + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 T significand_hidden_bits(T significand) { + if CUTLASS_CONSTEXPR_IF_CXX17 (NUM_MANTISSA_BITS == 0) { + return T(1); + } + return ((T(0b11) << T(NUM_MANTISSA_BITS)) & significand) >> T(NUM_MANTISSA_BITS); + } + + // Current assumption round to nearest even + template + CUTLASS_HOST_DEVICE + static CUTLASS_CONSTEXPR_IF_CXX17 T round_significand(T src, int shift_amount) { + T dst_mantissa = src; + // If the shift amount is positive, we are shifting left + // Type with less mantissa bits is rounded to a type with more + // mantissa bits. + if (shift_amount > 0) { + dst_mantissa = (dst_mantissa << (shift_amount)); + } + else { + // There are fewer mantissa bits in the target type + // we need to round the destination number up for all + // lower precision bits removed. + // We assume round-to-nearest-even here. + int pos_shift_amount = -shift_amount; + + // Too large shift return all zeros to prevent undefined behavior for shift. 
+ if (pos_shift_amount >= static_cast(sizeof(T) * 8)) { + return T(0); + } + + T guard_bit_mask = (T(1) << T(pos_shift_amount)); // Last bit to remain in mantissa + T sticky_mask = (T(1) << T(pos_shift_amount - 1)) - T(1); // Remaining bits + T round_bit_mask = (T(1) << T(pos_shift_amount - 1)); // First bit removed from mantissa + + bool sticky_bit = (src & sticky_mask) >= T(1); // ORing all sticky bits + bool round_bit = (src & round_bit_mask) >= T(1); + bool guard_bit = (src & guard_bit_mask) >= T(1); + + // Shift mantissa bits to right to remove lowest precision bits + dst_mantissa = dst_mantissa >> pos_shift_amount; + + if ((sticky_bit && round_bit) || (guard_bit && round_bit && !sticky_bit)) { + dst_mantissa += 1; + } + } + return dst_mantissa; + } + + template + CUTLASS_HOST_DEVICE + static typename DstFpBits::Storage convert( + SrcFpBits src_encoding, + typename SrcFpBits::Storage src_val, + DstFpBits dst_encoding) { + + using SrcT = typename SrcFpBits::Storage; + using DstT = typename DstFpBits::Storage; + using LargeStorage = typename cutlass::platform::conditional<(sizeof(SrcT) > sizeof(DstT)), SrcT, DstT>::type; + + LargeStorage src_sign_bit = src_encoding.sign_bit(src_val); + + // If the source is NaN, set the destination to NaN carrying the sign bit + if (src_encoding.is_nan(src_val)) { + return dst_encoding.nan_with_sign(DstT(src_sign_bit)); + } + // If the source is INF, set the destination to INF carrying the sign bit + else if (src_encoding.is_inf(src_val)) { + return dst_encoding.set_sign_bit(DstFpBits::INF_MASK, DstT(src_sign_bit)); + } + // Number is not NaN or INF: Zero and others + + LargeStorage src_exp_bits = src_encoding.exponent_bits(src_val); + LargeStorage src_significand = src_encoding.significand(src_val); + int src_exp = src_encoding.exponent(src_val); + + // The source value is 0. Return a signed 0. 
+ if (src_exp_bits == LargeStorage(0) && src_significand == LargeStorage(0)) { + return dst_encoding.set_sign_bit(DstT(0), DstT(src_sign_bit)); + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(1) src_sign: %llu src_exp_bits %llx src_exp %d src_significand %llx\n", + static_cast(src_sign_bit), static_cast(src_exp_bits), src_exp, static_cast(src_significand)); +#endif + // Normalize the number: Left shift the significand bits until hidden "1" appears. + // Only needed if the src value is denormal. + // Conditions: + // If the exponent is 0, then the significand can't be 0 (src_val==0 case handled above): + // there is at least one "1" bit in the significand. Loop executes. + // If the exponent is not 0, then the number is normal: + // significand has hidden bit set. Loop doesn't execute. + // Assumption: Zero is always defined for the floating point types and detected above + + while (src_encoding.significand_hidden_bits(src_significand) == LargeStorage(0)) { + src_significand <<= LargeStorage(1); + src_exp--; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(2) src_sign: %llu src_exp_bits %llx src_exp %d src_significand %llx\n", + static_cast(src_sign_bit), static_cast(src_exp_bits), src_exp, static_cast(src_significand)); +#endif + // The exponent exceeds DstFormat's exponent capacity + // Return positive/negative infinity. + // If no INF is defined, return positive/negative largest value. 
+ if (src_exp > DstFpBits::MAX_EXP) { + return dst_encoding.set_sign_bit(DstFpBits::INF_MASK, DstT(src_sign_bit)); + } + else if (src_exp <= DstFpBits::MAX_EXP && src_exp >= DstFpBits::MIN_EXP) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(3) Exp match: src_sign: %d src_exp_bits: %x src_exp: %d src_significand: %x\n", + src_sign_bit, src_exp_bits, src_exp, src_significand); +#endif + + int shift_amount = int(DstFpBits::NUM_MANTISSA_BITS) - int(SrcFpBits::NUM_MANTISSA_BITS); + int dst_exponent = src_exp + DstFpBits::EXP_BIAS; + LargeStorage dst_mantissa = src_significand; + + // if we have an M0 case, the floating point number is always denormal. + // Therefore, if exponents are equal, we need to check whether it is inf + if (DstFpBits::NUM_EXPONENT_BITS == 0) { + if (dst_mantissa > DstFpBits::INF_MASK) { + return dst_encoding.inf_with_sign(DstT(src_sign_bit)); + } + } + + // Round to nearest even + dst_mantissa = round_significand(dst_mantissa, shift_amount); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(4) after rounding src_sign: %d dst_exponent: %d dst_mantissa: %x\n", + src_sign_bit, dst_exponent, dst_mantissa); +#endif + + if (dst_encoding.significand_hidden_bits(dst_mantissa) > 0b1) { + // Significant became larger than 01.X...X. 
Divide significand by 2 and multiply exp by 2 + while (dst_exponent < (DstFpBits::MAX_EXP+DstFpBits::EXP_BIAS) && + dst_encoding.significand_hidden_bits(dst_mantissa) > LargeStorage(0b1)) { + dst_mantissa >>= LargeStorage(1); + dst_exponent++; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(5) after rounding max_exp: %d src_sign: %d dst_exponent: %d dst_mantissa: %x\n", + DstFpBits::MAX_EXP,src_sign_bit, dst_exponent, dst_mantissa); +#endif + + if (dst_encoding.significand_hidden_bits(dst_mantissa) > LargeStorage(0b1)) { + return dst_encoding.set_sign_bit(DstFpBits::INF_MASK, DstT(src_sign_bit)); + } + } + + dst_mantissa = dst_mantissa & DstFpBits::MANTISSA_MASK; + static_assert(sizeof(LargeStorage) >= sizeof(decltype(dst_exponent)), + "sizeof(LargeStorage) must be greater than or equal to sizeof(decltype(dst_exponent))"); + LargeStorage dst_exponent_bits = static_cast(dst_exponent); + + DstT final_val = static_cast(dst_encoding.template make_fp_from_bits(src_sign_bit, dst_exponent_bits, dst_mantissa)); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(6) Final Value src_sign: %d dst_exp_bits: %x dst_mantissa: %x\n", + src_sign_bit, dst_exponent_bits, dst_mantissa); +#endif + + if (DstFpBits::is_nan(final_val)) { + // This NAN is generated when: + // Src is not an Nan + // the exp of Src == the max_exp of Dst. + // The mantissa becomes all-1s after rounding. + // Return max value of Dst (not NAN) as it just couldn't be represented in the range of Dst. 
+ return dst_encoding.set_sign_bit(DstFpBits::INF_MASK, DstT(src_sign_bit)); + } + else { + return final_val; + } + } + else { + // Result is denormal +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(7) Denormal case src_sign: %d src_exp: %d src_significand: %x MIN_EXP: %d\n", + src_sign_bit, src_exp, src_significand, DstFpBits::MIN_EXP); +#endif + + int exp_diff = src_exp - DstFpBits::MIN_EXP; + int shift_amount = int(DstFpBits::NUM_MANTISSA_BITS) - int(SrcFpBits::NUM_MANTISSA_BITS); + shift_amount += exp_diff; + LargeStorage dst_mantissa = src_significand; + dst_mantissa = round_significand(dst_mantissa, shift_amount); + + if (dst_encoding.significand_hidden_bits(dst_mantissa) >= LargeStorage(0b1)) { + if CUTLASS_CONSTEXPR_IF_CXX17 (DstFpBits::NUM_EXPONENT_BITS == 0) { + return dst_encoding.inf_with_sign(DstT(src_sign_bit)); + } + else { + LargeStorage dst_exp_bits = 1; + dst_mantissa &= DstFpBits::MANTISSA_MASK; + DstT final_val = static_cast(dst_encoding.template make_fp_from_bits(src_sign_bit, dst_exp_bits, dst_mantissa)); + return final_val; + } + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(7.1) Denormal case exp_diff: %d shift_amount: %d dst_mantissa %d\n", exp_diff, shift_amount, dst_mantissa); +#endif + dst_mantissa &= DstFpBits::MANTISSA_MASK; + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("(8) Final Value src_sign: %d src_exp: %d dst_mantissa: %x\n", + src_sign_bit, src_exp, dst_mantissa); +#endif + + DstT final_val = static_cast(dst_encoding.template make_fp_from_bits(src_sign_bit, LargeStorage(0), dst_mantissa)); + return final_val; + } + + return DstT(0); + } + + template + friend struct FpBitRepresentation; +}; + +#if (CUTLASS_CXX17_OR_LATER) + +template +CUTLASS_CONSTEXPR_IF_CXX17 auto fp_encoding_selector() { + if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E11M52) { // double + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E8M23) { // float + return 
cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E5M2) { // FP8 + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E4M3) { // FP8 + return cutlass::detail::FpBitRepresentation{}; + } + + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::UE4M3) { // FP8 + return cutlass::detail::FpBitRepresentation{}; + } + + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::UE8M0) { // FP8 + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E3M2) { // FP6 + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E2M3) { // FP6 + return cutlass::detail::FpBitRepresentation{}; + } + else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E2M1) { // FP4 + return cutlass::detail::FpBitRepresentation{}; + } + CUTLASS_GCC_UNREACHABLE; +} + +#else +// +// Definitions for floating point encodings. 
+// + +template struct FpEncodingSelector { + using type = void; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; + +template <> struct FpEncodingSelector { + using type = cutlass::detail::FpBitRepresentation; +}; +#endif + +} // namespace detail + +template +struct float_exmy_base +{ + + static constexpr detail::FpEncoding Encoding = T; + using BitRepresentation = + #if (CUTLASS_CXX17_OR_LATER) + decltype(detail::fp_encoding_selector()) + #else + typename detail::FpEncodingSelector::type + #endif + ; + + using FP32BitRepresentation = + #if (CUTLASS_CXX17_OR_LATER) + decltype(cutlass::detail::fp_encoding_selector()) + #else + typename detail::FpEncodingSelector::type + #endif + ; + + using Storage = typename BitRepresentation::Storage; + + // + // Data members + // + + /// Data container + Storage storage; + + /// Ctors. 
+ float_exmy_base() = default; + + CUTLASS_HOST_DEVICE + float_exmy_base(Storage s) : storage(s) { + } + + /// Is finite implementation + CUTLASS_HOST_DEVICE + static bool isfinite(float_exmy_base flt) { + return !BitRepresentation::is_inf(flt.storage); + } + + /// Is NaN implementation + CUTLASS_HOST_DEVICE + static bool isnan(float_exmy_base flt) { + return BitRepresentation::is_nan(flt.storage); + } + + /// Is infinite implementation + CUTLASS_HOST_DEVICE + static bool isinf(float_exmy_base flt) { + return BitRepresentation::is_inf(flt.storage); + } + + /// Is infinite implementation + CUTLASS_HOST_DEVICE + static bool isnormal(float_exmy_base flt) { + return !BitRepresentation::is_denorm(flt.storage); + } + + CUTLASS_HOST_DEVICE + static float_exmy_base bitcast(Storage x) { + float_exmy_base f; + f.storage = x; + return f; + } + + CUTLASS_HOST_DEVICE + float_exmy_base convert_from_float(float const &flt) const { + FP32BitRepresentation::Storage fp32_bits = FP32BitRepresentation::to_bits(flt); + float_exmy_base float_exmy; + float_exmy.storage = BitRepresentation::convert_from(fp32_bits, FP32BitRepresentation{}); + return float_exmy; + } + + CUTLASS_HOST_DEVICE + float convert_to_float(float_exmy_base const &x) const { + FP32BitRepresentation::Storage fp32_bits; + fp32_bits = BitRepresentation::convert_to(x.storage, FP32BitRepresentation{}); + return detail::copy_bits(fp32_bits); + } + + // Note: Only consider float/int conversions in this Base class + // Types inheriting from this class should define their own constructors and + // specialized type conversions + + /// Floating point conversion + CUTLASS_HOST_DEVICE + explicit float_exmy_base(float x) { + storage = static_cast(this)->convert_from_float(x).storage; + } + + // Integer conversion + CUTLASS_HOST_DEVICE + explicit float_exmy_base(int x) { + storage = static_cast(this)->convert_from_float(float(x)).storage; + } + + CUTLASS_HOST_DEVICE + explicit float_exmy_base(unsigned x) { + storage = 
static_cast(this)->convert_from_float(float(x)).storage; + } + + /// Converts to float + CUTLASS_HOST_DEVICE + operator float() const { + return static_cast(this)->convert_to_float(*this); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(static_cast(this)->convert_to_float(*this)); + } + + /// Accesses raw internal state + CUTLASS_HOST_DEVICE + Storage &raw() { + return storage; + } + + /// Accesses raw internal state + CUTLASS_HOST_DEVICE + Storage raw() const { + return storage; + } + + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return bool(BitRepresentation::sign_bit(storage)); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int(BitRepresentation::exponent_bits(storage)); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return int(BitRepresentation::exponent(storage)); + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(BitRepresentation::mantissa_bits(storage)); + } + + /////////////////////////////////////////////////////////////////////////////////////////////////// + // + // Arithmetic operators + // + /////////////////////////////////////////////////////////////////////////////////////////////////// + + // Note: Almost all data types cast to float then do the arithmetic operations + // Types inheriting from this class can overload them if specialized instructions are available + // in HW (e.g. 
half_t) + + + CUTLASS_HOST_DEVICE + friend bool operator==(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) == float(rhs); + } + + CUTLASS_HOST_DEVICE + friend bool operator!=(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) != float(rhs); + } + + CUTLASS_HOST_DEVICE + friend bool operator<(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) < float(rhs); + } + + CUTLASS_HOST_DEVICE + friend bool operator<=(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) <= float(rhs); + } + + CUTLASS_HOST_DEVICE + friend bool operator>(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) > float(rhs); + } + + CUTLASS_HOST_DEVICE + friend bool operator>=(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float(lhs) >= float(rhs); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator+(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float_exmy_base(float(lhs) + float(rhs)); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator-(float_exmy_base const &lhs) { + return float_exmy_base(-float(lhs)); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator-(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float_exmy_base(float(lhs) - float(rhs)); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator*(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float_exmy_base(float(lhs) * float(rhs)); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator/(float_exmy_base const &lhs, float_exmy_base const &rhs) { + return float_exmy_base(float(lhs) / float(rhs)); + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator+=(float_exmy_base &lhs, float_exmy_base const &rhs) { + lhs = float_exmy_base(float(lhs) + float(rhs)); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator-=(float_exmy_base &lhs, float_exmy_base const &rhs) { + 
lhs = float_exmy_base(float(lhs) - float(rhs)); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator*=(float_exmy_base &lhs, float_exmy_base const &rhs) { + lhs = float_exmy_base(float(lhs) * float(rhs)); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator/=(float_exmy_base &lhs, float_exmy_base const &rhs) { + lhs = float_exmy_base(float(lhs) / float(rhs)); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator++(float_exmy_base &lhs) { + float tmp(lhs); + ++tmp; + lhs = float_exmy_base(tmp); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base &operator--(float_exmy_base &lhs) { + float tmp(lhs); + --tmp; + lhs = float_exmy_base(tmp); + return lhs; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator++(float_exmy_base &lhs, int) { + float_exmy_base ret(lhs); + float tmp(lhs); + tmp++; + lhs = float_exmy_base(tmp); + return ret; + } + + CUTLASS_HOST_DEVICE + friend float_exmy_base operator--(float_exmy_base &lhs, int) { + float_exmy_base ret(lhs); + float tmp(lhs); + tmp--; + lhs = float_exmy_base(tmp); + return ret; + } + +}; + +template +CUTLASS_HOST_DEVICE +cutlass::float_exmy_base abs(cutlass::float_exmy_base const& h) { + using BitRepresentation = typename cutlass::float_exmy_base::BitRepresentation; + using Storage = typename cutlass::float_exmy_base::Storage; + return BitRepresentation::IS_SIGNED ? 
+ cutlass::float_exmy_base(Storage(h.raw() & Storage((1<(h.raw()); +} +} // namespace cutlass diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/device/detail.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/device/detail.hpp new file mode 100644 index 0000000000000000000000000000000000000000..129f733725d22bdcdfa4b55a9d52afb031adc908 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/device/detail.hpp @@ -0,0 +1,163 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Distributed gemm device layer helpers. +*/ + +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::distributed::device::detail { + + +cutlass::Status check_cuda_status(cudaError_t status) { + if (status != cudaSuccess) { + auto result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" error message: " << cudaGetErrorString(result)); + return cutlass::Status::kErrorInternal; + } + return cutlass::Status::kSuccess; +} + +// DistGemmBufferHelper computes required buffer size and offsets for GEMM operands. 
+template < + typename Tiler_, + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementD_> +struct DistGemmBufferHelper { + + using Tiler = Tiler_; + + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementC = ElementC_; + using ElementD = ElementD_; + + static constexpr int NumBuffersA = Tiler::NumBuffersA; + static constexpr int NumBuffersB = Tiler::NumBuffersB; + static constexpr int NumBuffersC = Tiler::NumBuffersC; + static constexpr int NumBuffersD = Tiler::NumBuffersD; + + template + static auto + get_buffer_size_a(ProblemShape problem_shape) { + auto a_buffer_layout = cute::make_layout( + cute::make_shape(NumBuffersA, Tiler::get_local_a_shape(problem_shape), sizeof(ElementA)) + ); + return size(a_buffer_layout); + } + + template + static auto + get_buffer_size_b(ProblemShape problem_shape) { + auto b_buffer_layout = cute::make_layout( + cute::make_shape(NumBuffersB, Tiler::get_local_b_shape(problem_shape), sizeof(ElementB)) + ); + return size(b_buffer_layout); + } + + template + static auto + get_buffer_size_c(ProblemShape problem_shape) { + auto c_buffer_layout = cute::make_layout( + cute::make_shape(NumBuffersC, Tiler::get_local_c_shape(problem_shape), sizeof(ElementC)) + ); + return size(c_buffer_layout); + } + + template + static auto + get_buffer_size_d(ProblemShape problem_shape) { + auto d_buffer_layout = cute::make_layout( + cute::make_shape(NumBuffersD, Tiler::get_local_d_shape(problem_shape), sizeof(ElementD)) + ); + return size(d_buffer_layout); + } + + template + static auto + get_buffer_size(ProblemShape problem_shape) { + size_t buffer_size = 0; + + if constexpr (NumBuffersA > 0) { + buffer_size += get_buffer_size_a(problem_shape); + } + if constexpr (NumBuffersB > 0) { + buffer_size += get_buffer_size_b(problem_shape); + } + if constexpr (NumBuffersC > 0) { + buffer_size += get_buffer_size_c(problem_shape); + } + if constexpr (NumBuffersD > 0) { + buffer_size += get_buffer_size_d(problem_shape); 
+ } + + return buffer_size; + } + + // Buffer space: | buffer_A | buffer_B | buffer_C | buffer_D | + // And buffer_{A,B,C,D}: | iter 1 | iter 2 | ... | iter TP - 1 | + template + static size_t + get_buffer_offset_A(ProblemShape problem_shape) { + return 0; + } + + template + static size_t + get_buffer_offset_B(ProblemShape problem_shape) { + return get_buffer_size_a(problem_shape); + } + + template + static size_t + get_buffer_offset_C(ProblemShape problem_shape) { + return get_buffer_size_a(problem_shape) + get_buffer_size_b(problem_shape); + } + + template + static size_t + get_buffer_offset_D(ProblemShape problem_shape) { + return get_buffer_size_a(problem_shape) + get_buffer_size_b(problem_shape) + get_buffer_size_c(problem_shape); + } +}; + +} // namespace cutlass::distributed::device::detail + +/////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7968849a87d228aef5e5e39afcb705e1595fcd4f --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp @@ -0,0 +1,717 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file Distributed GEMM Device Adapter + + Sets up local GEMM stages, the cuda graph, manages buffer and barrier spaces, + and maps arguments to per-stage arguments. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/experimental/distributed/device/full_barrier.hpp" +#include "cutlass/experimental/distributed/device/detail.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::distributed::device { + +template +class DistributedGemmUniversalAdapter { +public: + using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter; + using GemmKernel = GemmKernel_; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + // "Inherit" type decls and static values from device GEMM + using LayoutA = typename DeviceGemm::LayoutA; + using LayoutB = typename DeviceGemm::LayoutB; + using LayoutC = typename DeviceGemm::LayoutC; + using LayoutD = typename DeviceGemm::LayoutD; + + using StrideA = typename GemmKernel::StrideA; + using StrideB = typename GemmKernel::StrideB; + using StrideC = typename GemmKernel::StrideC; + using StrideD = typename GemmKernel::StrideD; + + static bool const kEnableCudaHostAdapter = DeviceGemm::kEnableCudaHostAdapter; + + static ComplexTransform const kTransformA = DeviceGemm::kTransformA; + static ComplexTransform const kTransformB = DeviceGemm::kTransformB; + + using MathOperator = typename DeviceGemm::MathOperator; + using OperatorClass = typename DeviceGemm::OperatorClass; + using ArchTag = typename DeviceGemm::ArchTag; + + using ThreadblockSwizzle = 
typename DeviceGemm::ThreadblockSwizzle; + using ThreadblockShape = typename DeviceGemm::ThreadblockShape; + using ClusterShape = typename DeviceGemm::ClusterShape; + using InstructionShape = typename DeviceGemm::InstructionShape; + + static int const kThreadCount = DeviceGemm::kThreadCount; + static constexpr int WarpsInMma = DeviceGemm::WarpsInMma; + static constexpr int WarpsInMmaM = DeviceGemm::WarpsInMmaM; + static constexpr int WarpsInMmaN = DeviceGemm::WarpsInMmaN; + + using WarpCount = typename DeviceGemm::WarpCount; + using WarpShape = typename DeviceGemm::WarpShape; + + static int constexpr kStages = DeviceGemm::kStages; + + static int constexpr kAlignmentA = DeviceGemm::kAlignmentA; + static int constexpr kAlignmentB = DeviceGemm::kAlignmentB; + static int constexpr kAlignmentC = DeviceGemm::kAlignmentC; + static int constexpr kAlignmentD = DeviceGemm::kAlignmentD; + + using EpilogueOutputOp = typename DeviceGemm::EpilogueOutputOp; + + static int constexpr kSplitKAlignment = DeviceGemm::kSplitKAlignment; + + // Distributed GEMM types and defs + using DistSchedule = typename GemmKernel::DistSchedule; + static constexpr bool HasMemcpy = DistSchedule::HasMemcpy; + using TP = typename DistSchedule::TP; + static constexpr int TP_ = TP{}; + using ElementFlag = typename GemmKernel::ElementFlag; + using ElementBarrier = uint32_t; + + using BufferHelper = detail::DistGemmBufferHelper< + DistSchedule, + ElementA, + ElementB, + ElementC, + ElementD>; + + /// Argument structure + using Arguments = typename GemmKernel::BaseArguments; + using DistributedArguments = typename GemmKernel::DistributedArguments; + using PackedArguments = typename GemmKernel::PackedArguments; + + /// Argument structure: Kernel API + using Params = typename GemmKernel::PackedParams; + + struct DistributedGemmState { + int device_idx; + + Params params_array[TP_]; + + cudaGraph_t graph; + cudaGraphExec_t graph_executable; + + bool graph_created = false; + bool graph_instantiated = false; + + 
void * memcpy_source_ptr_array[TP_]; + void const * memcpy_remote_ptr_array[TP_]; + size_t memcpy_bytes[TP_]; + + cutlass::Array device_barrier_ptrs; + + bool is_initialized = false; + }; + +private: + + DistributedGemmState state_; + +public: + + bool is_initialized() { + return state_.is_initialized && state_.graph_created && state_.graph_instantiated; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (args.epilogue.thread.beta != 0.0 && DistSchedule::RemoteC) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Selected TP uses Remote C to communicate " << + "partial results, which do not support non-zero values for beta yet " << + "(epilogue must be sourceless.)\n"); + return Status::kInvalid; + } + + if (not DistSchedule::can_implement_global(args.problem_shape)) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem shape not divisible by TP.\n"); + return Status::kInvalid; + } + + Arguments args_copy = args; + args_copy.problem_shape = DistSchedule::get_local_gemm_shape(args.problem_shape); + for (int iteration = 0; iteration < TP_; ++iteration) { + if (not GemmKernel::can_implement(args_copy)) { + return Status::kInvalid; + } + } + return Status::kSuccess; + } + + /// Gets buffer space size + static size_t + get_buffer_space_size(Arguments const& args) { + size_t buffer_bytes = 0; + + buffer_bytes = BufferHelper::get_buffer_size(args.problem_shape); + buffer_bytes = round_nearest(buffer_bytes, MinWorkspaceAlignment); + + return buffer_bytes; + } + + static auto + get_tensor_A_for_iter(Arguments const* args_array, void** buffer_space, int device_idx, int iteration) { + auto args = args_array[device_idx]; + auto tensor_A = make_tensor(args.mainloop.ptr_A, make_layout( + DistSchedule::get_local_a_shape(args.problem_shape), + args.mainloop.dA)); + + uint8_t* tensor_buffer = reinterpret_cast(buffer_space[device_idx]) + + BufferHelper::get_buffer_offset_A(args.problem_shape); + + return 
DistSchedule::get_tensor_A(tensor_A, tensor_buffer, device_idx, iteration); + } + + static auto + get_tensor_B_for_iter(Arguments const* args_array, void** buffer_space, int device_idx, int iteration) { + auto args = args_array[device_idx]; + auto tensor_B = make_tensor(args.mainloop.ptr_B, make_layout( + DistSchedule::get_local_b_shape(args.problem_shape), + args.mainloop.dB)); + + uint8_t* tensor_buffer = reinterpret_cast(buffer_space[device_idx]) + + BufferHelper::get_buffer_offset_B(args.problem_shape); + + return DistSchedule::get_tensor_B(tensor_B, tensor_buffer, device_idx, iteration); + } + + static auto + get_tensor_C_for_iter(Arguments const* args_array, void** buffer_space, int device_idx, int iteration) { + auto args = args_array[device_idx]; + auto tensor_C = make_tensor(args.epilogue.ptr_C, make_layout( + DistSchedule::get_local_c_shape(args.problem_shape), + args.epilogue.dC)); + + auto peer_idx_iter = DistSchedule::get_remote_peer_id(device_idx, iteration); + void* buffer_ptr = DistSchedule::RemoteC ? 
buffer_space[peer_idx_iter] : buffer_space[device_idx]; + + uint8_t* tensor_buffer = reinterpret_cast(buffer_ptr) + + BufferHelper::get_buffer_offset_C(args.problem_shape); + + return DistSchedule::get_tensor_C(tensor_C, tensor_buffer, device_idx, iteration); + } + + static auto + get_tensor_D_for_iter(Arguments const* args_array, void** buffer_space, int device_idx, int iteration) { + auto args = args_array[device_idx]; + auto tensor_D = make_tensor(args.epilogue.ptr_D, make_layout( + DistSchedule::get_local_d_shape(args.problem_shape), + args.epilogue.dD)); + + // support remoteD + uint8_t* tensor_buffer = reinterpret_cast(buffer_space[device_idx]) + + BufferHelper::get_buffer_offset_D(args.problem_shape); + + return DistSchedule::get_tensor_D(tensor_D, tensor_buffer, device_idx, iteration); + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + + workspace_bytes = get_buffer_space_size(args); + + for (int iteration = 0; iteration < TP_; ++iteration) { + // NOTE: assumes underlying kernels align up to alignment requirements on their own, + // and that the alignment requirements of the individual kernels match. + workspace_bytes += GemmKernel::get_workspace_size(args); + } + + return workspace_bytes; + } + + static size_t + get_barrier_bytes() { + return round_nearest(sizeof(ElementBarrier), 32); + } + + static size_t + get_flag_bytes() { + return round_nearest(sizeof(ElementFlag) * TP_, 32); + } + + static void * + exclusive_workspace_ptr_to_flag_ptr(void * exclusive_workspace_ptr, int iteration) { + return static_cast( + static_cast(exclusive_workspace_ptr) + + get_barrier_bytes() + + (sizeof(ElementFlag) * iteration)); + } + + static size_t + get_exclusive_workspace_size() { + return get_barrier_bytes() + get_flag_bytes(); + } + + /// Initializes GEMM state from arguments. 
+ Status + initialize( + Arguments const* args, + void** workspace_ptrs, + void** exclusive_workspace_ptrs, + int device_idx, + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) { + + CUTLASS_TRACE_HOST("DistributedGemm::initialize() - stream: " << (stream ? "non-null" : "null")); + + state_.device_idx = device_idx; + + for (int device = 0; device < TP_; ++device) { + state_.device_barrier_ptrs[device] = reinterpret_cast(exclusive_workspace_ptrs[device]); + } + + // Zero out exclusive workspace + zero_workspace(exclusive_workspace_ptrs[device_idx], get_exclusive_workspace_size(), stream, nullptr); + + for (int iteration = 0; iteration < TP_; ++iteration) { + + size_t workspace_iteration_offset = GemmKernel::get_workspace_size(args[device_idx]); + uint8_t* workspace_ptr = reinterpret_cast(workspace_ptrs[device_idx]) + + get_buffer_space_size(args[device_idx]) + + (iteration * workspace_iteration_offset); + + void * workspace_iter = reinterpret_cast(workspace_ptr); + void** buffer_space = workspace_ptrs; + + // Set up GEMM arguments for the current stage/iteration + auto tensor_a_iter = get_tensor_A_for_iter(args, buffer_space, device_idx, iteration); + auto tensor_b_iter = get_tensor_B_for_iter(args, buffer_space, device_idx, iteration); + auto tensor_c_iter = get_tensor_C_for_iter(args, buffer_space, device_idx, iteration); + auto tensor_d_iter = get_tensor_D_for_iter(args, buffer_space, device_idx, iteration); + + Arguments base_args = args[device_idx]; + base_args.problem_shape = DistSchedule::get_local_gemm_shape(args[device_idx].problem_shape); + base_args.mainloop = { + reinterpret_cast(tensor_a_iter.data()), + tensor_a_iter.stride(), + reinterpret_cast(tensor_b_iter.data()), + tensor_b_iter.stride() + }; + base_args.epilogue = { + base_args.epilogue.thread, + reinterpret_cast(tensor_c_iter.data()), + tensor_c_iter.stride(), + reinterpret_cast(tensor_d_iter.data()), + tensor_d_iter.stride() + }; + + if constexpr (DistSchedule::RemoteC) { + if 
(iteration > 0) { + base_args.epilogue.thread.beta = 1.0; + } + else if (iteration == 0){ + base_args.epilogue.thread.beta = 0.0; + } + } + + auto [left_peer_idx, right_peer_idx] = DistSchedule::get_peers_for_device(device_idx); + auto flag_peer_idx = DistSchedule::KernelWritesArrivalFlag ? right_peer_idx : device_idx; + + void * self_flag_ptr = exclusive_workspace_ptr_to_flag_ptr(exclusive_workspace_ptrs[device_idx], iteration); + void * peer_flag_ptr = exclusive_workspace_ptr_to_flag_ptr(exclusive_workspace_ptrs[flag_peer_idx], iteration); + + DistributedArguments distributed_args = { + device_idx, + iteration, + self_flag_ptr, + peer_flag_ptr + }; + PackedArguments args_iter = {base_args, distributed_args}; + + // Initialize the workspace + Status status = GemmKernel::initialize_workspace(args_iter, workspace_iter, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + state_.params_array[iteration] = GemmKernel::to_underlying_arguments(args_iter, workspace_iter); + + // Set up peer buffer ptrs + if (iteration > 0 && HasMemcpy) { + auto peer_idx_iter = DistSchedule::get_remote_peer_id(device_idx, iteration); + + void * local_ptr_itr = nullptr; + void const * remote_ptr_itr = nullptr; + size_t local_size = 0; + size_t remote_size = 0; + + static_assert(not DistSchedule::HasMemcpy || ( + DistSchedule::MemcpyA || DistSchedule::MemcpyB), + "Expected to either memcpy A or B when scheduler requires memcpy."); + if constexpr (DistSchedule::MemcpyA) { + local_size = cute::cosize(tensor_a_iter.layout()) * sizeof(ElementA); + local_ptr_itr = reinterpret_cast(tensor_a_iter.data()); + + // Copy peer's slice in the first iteration (direct access memcpy instead of logical ring) + auto remote_tensor_iter = get_tensor_A_for_iter(args, buffer_space, peer_idx_iter, 0); + remote_ptr_itr = reinterpret_cast(remote_tensor_iter.data()); + remote_size = cute::cosize(remote_tensor_iter.layout()) * sizeof(ElementA); + } + else if 
constexpr (DistSchedule::MemcpyB) { + local_size = cute::cosize(tensor_b_iter.layout()) * sizeof(ElementB); + local_ptr_itr = reinterpret_cast(tensor_b_iter.data()); + + // Copy peer's slice in the first iteration (direct access memcpy instead of logical ring) + auto remote_tensor_iter = get_tensor_B_for_iter(args, buffer_space, peer_idx_iter, 0); + remote_ptr_itr = reinterpret_cast(remote_tensor_iter.data()); + remote_size = cute::cosize(remote_tensor_iter.layout()) * sizeof(ElementB); + } + + assert(local_size == remote_size && local_size > 0); + + state_.memcpy_source_ptr_array[iteration] = local_ptr_itr; + state_.memcpy_remote_ptr_array[iteration] = remote_ptr_itr; + state_.memcpy_bytes[iteration] = local_size; + } + } + + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + state_.is_initialized = true; + + // Instantiate graph + Status status = construct_graph(launch_with_pdl); + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + Status + construct_graph(bool launch_with_pdl) { +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)) + Status status = Status::kSuccess; + + // Destroy existing graph, if created + if (state_.graph_created) { + status = detail::check_cuda_status(cudaGraphDestroy(state_.graph)); + if (status != Status::kSuccess) { + return status; + } + } + + state_.graph_created = true; + + cudaGraphNode_t full_barrier_node; + + // Create dummy stream + cudaStream_t stream; + 
status = detail::check_cuda_status(cudaStreamCreate(&stream)); + if (status != Status::kSuccess) { + return status; + } + + // Create graph + status = detail::check_cuda_status(cudaGraphCreate(&state_.graph, 0)); + if (status != Status::kSuccess) { + return status; + } + + // 1. Full barrier node + status = detail::check_cuda_status(cudaStreamBeginCaptureToGraph( + stream, + state_.graph, + nullptr, nullptr, 0, + cudaStreamCaptureModeRelaxed)); + if (status != Status::kSuccess) { + return status; + } + + cutlass::Array self_flag_ptrs; + for (int iteration = 0; iteration < TP_; ++iteration) { + self_flag_ptrs[iteration] = state_.params_array[iteration].distributed.self_flag_ptr_; + } + + launch_full_barrier( + state_.device_barrier_ptrs, self_flag_ptrs, state_.device_idx, stream, launch_with_pdl); + + status = detail::check_cuda_status(cudaStreamEndCapture(stream, &state_.graph)); + if (status != Status::kSuccess) { + return status; + } + + size_t num_nodes; + status = detail::check_cuda_status(cudaGraphGetNodes(state_.graph, nullptr, &num_nodes)); + if (status != Status::kSuccess) { + return status; + } + if (num_nodes != 1) { + CUTLASS_TRACE_HOST(" construct_graph() failure: expected a single node in the graph, got " << num_nodes << "."); + return Status::kErrorInternal; + } + if (status != Status::kSuccess) { + return status; + } + status = detail::check_cuda_status(cudaGraphGetNodes(state_.graph, &full_barrier_node, &num_nodes)); + if (status != Status::kSuccess) { + return status; + } + + // 2. Optional mem copy branch + if constexpr (HasMemcpy) { + + status = detail::check_cuda_status(cudaStreamBeginCaptureToGraph( + stream, + state_.graph, + &full_barrier_node, + /* dependencyData = */ nullptr, + 1, + cudaStreamCaptureModeRelaxed)); + + if (status != Status::kSuccess) { + return status; + } + + // No copies for first iter; we assume the data is already there. 
+ for (int iteration = 1; iteration < TP_; ++iteration) { + + status = detail::check_cuda_status(cudaMemcpyAsync( + state_.memcpy_source_ptr_array[iteration], + state_.memcpy_remote_ptr_array[iteration], + state_.memcpy_bytes[iteration], + cudaMemcpyDeviceToDevice, stream)); + + if (status != Status::kSuccess) { + return status; + } + + // Set flag to non zero + status = detail::check_cuda_status(cudaMemsetAsync( + reinterpret_cast(state_.params_array[iteration].distributed.peer_flag_ptr_), + 0b11111111, + sizeof(ElementFlag), + stream)); + + if (status != Status::kSuccess) { + return status; + } + } + + status = detail::check_cuda_status(cudaStreamEndCapture(stream, &state_.graph)); + if (status != Status::kSuccess) { + return status; + } + } + + // 3. Run local GEMMs + // 3.1. Create edge between full barrier and the correct gemm stage/iteration + cudaGraphEdgeData barrier_to_gemm_edge = {}; + barrier_to_gemm_edge.from_port = HasMemcpy ? cudaGraphKernelNodePortLaunchCompletion: cudaGraphKernelNodePortProgrammatic; + barrier_to_gemm_edge.type = cudaGraphDependencyTypeProgrammatic; + + status = detail::check_cuda_status(cudaStreamBeginCaptureToGraph( + stream, + state_.graph, + &full_barrier_node, + /* dependencyData = */ &barrier_to_gemm_edge, + 1, + cudaStreamCaptureModeRelaxed)); + if (status != Status::kSuccess) { + return status; + } + + for (int iteration = 0; iteration < TP_; ++iteration) { + status = DeviceGemm::run( + state_.params_array[iteration], + stream, + /* cuda_adapter = */ nullptr, + /* launch_with_pdl = */ launch_with_pdl); + + if (status != Status::kSuccess) { + return status; + } + } + + status = detail::check_cuda_status(cudaStreamEndCapture(stream, &state_.graph)); + if (status != Status::kSuccess) { + return status; + } + + // 4. Cleanup. + //// Destroy dummy stream + status = detail::check_cuda_status(cudaStreamDestroy(stream)); + if (status != Status::kSuccess) { + return status; + } + + // 5. 
Instantiate graph + status = detail::check_cuda_status(cudaGraphInstantiate( + &state_.graph_executable, + state_.graph, + /* flags = */ 0)); + if (status != Status::kSuccess) { + return status; + } + state_.graph_instantiated = true; + + return Status::kSuccess; +#else + CUTLASS_TRACE_HOST(" construct_graph() failure: target was compiled with an incompatible " << + "version of the CUDA toolkit. Please compile Distributed GEMM with CUDA toolkit 12.4 or later."); + return Status::kErrorInternal; +#endif + } + + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST(" DistributedGemm does not support updating arguments yet."); + return Status::kErrorInternal; + } + + // NOTE: the interface for run() is different in Distributed Gemm: + // 1. launch_with_pdl is specified in `initialize`, where the cuda graph is being constructed, + // 2. the state of distributed gemm is an array of params for different iterations, and a + // cuda graph. + // 3. Custom cuda adapters aren't supported for simplicity. + static Status + run(DistributedGemmState& state, + cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("DistributedGemm::run()"); + + if (not state.is_initialized) { + CUTLASS_TRACE_HOST(" Distributed gemm was not initialized. Did you forget to call initialize()?"); + return Status::kErrorInternal; + } + + if (not state.graph_instantiated) { + CUTLASS_TRACE_HOST(" Distributed gemm graph was not instantiated. Did you forget to call initialize()/construct_graph()?"); + return Status::kErrorInternal; + } + + cudaError_t result = cudaGraphLaunch(state.graph_executable, stream); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaGraphLaunch() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. 
+ // + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run( + cudaStream_t stream = nullptr) { + return run(state_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(state_, stream); + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const* args, + void** workspace_ptrs, + void** exclusive_workspace_ptrs, + int device_idx, + cudaStream_t stream = nullptr) { + Status status = initialize( + args, + workspace_ptrs, + exclusive_workspace_ptrs, + device_idx, + stream); + + if (Status::kSuccess == status) { + status = run(stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const* args, + void** workspace_ptrs, + void** exclusive_workspace_ptrs, + int device_idx, + cudaStream_t stream = nullptr) { + return run( + args, + workspace_ptrs, + exclusive_workspace_ptrs, + device_idx, + stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::distributed::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/device/full_barrier.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/device/full_barrier.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ab91cf890a0e544d685689e7081cf904e626813d --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/device/full_barrier.hpp @@ -0,0 +1,74 @@ 
+/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Device layer interface for Distributed GEMM barrier kernel. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/experimental/distributed/kernel/full_barrier.hpp" + +namespace cutlass::distributed::device { + +template +void launch_full_barrier( + cutlass::Array device_arrival_ptrs, + cutlass::Array iteration_flag_ptrs, + IntType device_idx, + cudaStream_t stream, + bool launch_with_pdl) { + +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)) + // Legacy (kernel) launch with PDL + cudaLaunchAttribute attributes[1]; + attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attributes[0].val.programmaticStreamSerializationAllowed = 1; + + cudaLaunchConfig_t launch_config; + launch_config.gridDim = 1; + launch_config.blockDim = 1; + launch_config.dynamicSmemBytes = 0; + launch_config.stream = stream; + launch_config.attrs = attributes; + launch_config.numAttrs = launch_with_pdl ? 1 : 0; + + cudaLaunchKernelEx( + &launch_config, + cutlass::distributed::kernel::full_barrier_kernel, + device_arrival_ptrs, + iteration_flag_ptrs, + device_idx); +#endif +} + +} // namespace cutlass::distributed::device + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/kernel/detail.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/kernel/detail.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0445567ee4dd67cb8f0139fe3ae6a16291b5689a --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/kernel/detail.hpp @@ -0,0 +1,72 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Distributed gemm kernel layer helpers. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::distributed::kernel::detail { + +// Ld with CV cache hint (don’t cache and fetch again) +// Reference: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#cache-operators +// Used for loading arrival counts from peer devices + +CUTLASS_DEVICE +void ld_without_cache(uint64_t& val, void const * ptr) { + asm volatile( + "{\n" + " ld.global.cv.u64 %0, [%1];\n" + "}\n" + : "=l"(val) + : "l"(ptr)); +} + +CUTLASS_DEVICE +void ld_without_cache(uint32_t& val, void const * ptr) { + asm volatile( + "{\n" + " ld.global.cv.u32 %0, [%1];\n" + "}\n" + : "=r"(val) + : "l"(ptr)); +} + +} // namespace cutlass::distributed::kernel::detail + +/////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b29003104508dd6ad1cecaa43aaf38fdba017463 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/kernel/dist_gemm_kernel_wrapper.hpp @@ -0,0 +1,235 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file Distributed GEMM Kernel Wrapper + + Prepends CUTLASS 3 GEMM kernels with barriers and other necessary instructions to exectue + a Distributed GEMM stage. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/experimental/distributed/kernel/detail.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::distributed::kernel { + +namespace detail { + +// Allow all CUTLASS 3.X GEMM kernels +template +struct SupportsDistributedGemm: cutlass::gemm::detail::IsCutlass3GemmKernel {}; + +} // namespace detail + +/*! + DistributedGemmKernelWrapper is a wrapper around a GEMM kernel. + + Depending on the underlying distribution policy/schedule, it prepends the underlying local GEMM + kernel with a few additional instructions that gate the execution of the GEMM on buffers being + ready for stages/iterations > 0. +*/ + +template +struct DistributedGemmKernelWrapper; + +template +struct DistributedGemmKernelWrapper< + GemmKernel_, + DistSchedule_, + cute::enable_if_t::value> + >: GemmKernel_ +{ + using DistSchedule = DistSchedule_; + using TP = typename DistSchedule::TP; + + static constexpr bool KernelWritesArrivalFlag = DistSchedule::KernelWritesArrivalFlag; + + using BaseKernel = GemmKernel_; + using BaseArguments = typename BaseKernel::Arguments; + using BaseParams = typename BaseKernel::Params; + + //static_assert(BaseKernel::ArchTag::kMinComputeCapability == 90, "DistGEMM only supports Hopper GEMMs for now."); + static_assert(not cute::is_same_v, "DistributedGEMM epilogues must have a source."); + + using ElementFlag = uint32_t; + + // Device side arguments + struct DistributedArguments { + int device_idx = 0; + int iteration = 0; + + void* self_flag_ptr{nullptr}; + void* peer_flag_ptr{nullptr}; + }; + + struct PackedArguments { + BaseArguments base{}; + DistributedArguments distributed{}; + }; + + struct DistributedParams { + int device_idx = 0; + int iteration = 0; + + ElementFlag* self_flag_ptr_{nullptr}; + ElementFlag* peer_flag_ptr_{nullptr}; + }; + + // Kernel entry 
point API + struct PackedParams { + BaseParams base{}; + DistributedParams distributed{}; + }; + + using Params = PackedParams; + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + PackedParams + to_underlying_arguments(PackedArguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("distributed::to_underlying_arguments():"); + + auto kernel_params = BaseKernel::to_underlying_arguments(args.base, workspace); + + DistributedParams dist_params = { + args.distributed.device_idx, + args.distributed.iteration, + reinterpret_cast(args.distributed.self_flag_ptr), + reinterpret_cast(args.distributed.peer_flag_ptr) + }; + + return {kernel_params, dist_params}; + } + + static bool + can_implement(BaseArguments const& args) { + return BaseKernel::can_implement(args); + } + + static bool + can_implement(PackedArguments const& args) { + return BaseKernel::can_implement(args.base); + } + + static size_t + get_workspace_size(BaseArguments const& args) { + return BaseKernel::get_workspace_size(args); + } + + static size_t + get_workspace_size(PackedArguments const& args) { + return BaseKernel::get_workspace_size(args.base); + } + + static cutlass::Status + initialize_workspace(BaseArguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return BaseKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + } + + static cutlass::Status + initialize_workspace(PackedArguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return BaseKernel::initialize_workspace(args.base, workspace, stream, cuda_adapter); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(PackedParams const& params) { + return BaseKernel::get_grid_shape(params.base); + } + + static dim3 + get_grid_shape(BaseParams const& params) { + return BaseKernel::get_grid_shape(params); + } + + CUTLASS_DEVICE + 
void + barrier_buffer(PackedParams const& params) { + if (params.distributed.iteration > 0) { + + ElementFlag comm_iter = 0; + detail::ld_without_cache(comm_iter, params.distributed.self_flag_ptr_); + while (comm_iter == 0) { + detail::ld_without_cache(comm_iter, params.distributed.self_flag_ptr_); + __nanosleep(40); + } + + } + } + + CUTLASS_DEVICE + void + maybe_signal_arrival(PackedParams const& params) { + if constexpr (KernelWritesArrivalFlag) { + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + params.distributed.iteration > 0) { + *reinterpret_cast(params.distributed.peer_flag_ptr_) = 1; + } + } + } + + CUTLASS_DEVICE + void + operator()(PackedParams const& params, char* smem_buf) { + // Launch next grid as soon as possible + arch::launch_dependent_grids(); + + // Wait on previous kernels to flush their memory. + arch::wait_on_dependent_grids(); + + // Optionally write arrivals for the previous stage/iteration. + maybe_signal_arrival(params); + + // Spin-wait on an arrival flag, make sure the respective buffers are ready. + // If the buffered operand is memcpied into, it would wait on its local flag. + // If it's a remote buffer that is accessed directly, it would wait on its remote flag. 
+ barrier_buffer(params); + + // Perform local gemm + BaseKernel gemm; + gemm(params.base, smem_buf); + } + +}; + +} // namespace cutlass::distributed::kernel + +/////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/kernel/full_barrier.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/kernel/full_barrier.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0ec620a536f258dea265a4e6c7fd55ee7a3168be --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/kernel/full_barrier.hpp @@ -0,0 +1,82 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Distributed GEMM barrier kernel. + + The kernel resets the per-stage arrival flags, performs a full barrier (any-to-any), + and also atomically resets the local barrier arrival count. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/grid_dependency_control.h" + +#include "cutlass/experimental/distributed/kernel/detail.hpp" + +namespace cutlass::distributed::kernel { + +template +__global__ void full_barrier_kernel( + cutlass::Array device_arrival_ptrs, + cutlass::Array iteration_flag_ptrs, + IntType device_idx) { + + arch::launch_dependent_grids(); + arch::wait_on_dependent_grids(); + + CUTLASS_PRAGMA_UNROLL + for (FlagType i = 0; i < Iterations; ++i) { + iteration_flag_ptrs[i][0] = static_cast(0); + } + + IntType val = 1; + IntType max_val = static_cast(NP - 1); + + CUTLASS_PRAGMA_UNROLL + for (IntType d = 0; d < NP; ++d) { + if (d != device_idx) { + atomicAdd(device_arrival_ptrs[d], val); + } + } + + IntType curr_val = 0; + detail::ld_without_cache(curr_val, device_arrival_ptrs[device_idx]); + while (curr_val < max_val) { + __nanosleep(40); + detail::ld_without_cache(curr_val, device_arrival_ptrs[device_idx]); + } + + atomicSub(device_arrival_ptrs[device_idx], max_val); +} + +} // namespace cutlass::distributed::kernel + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp new file mode 100644 index 0000000000000000000000000000000000000000..73d52adcbb457f71a51c30a41a08bc787777c7d7 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_1d_schedules.hpp @@ -0,0 +1,324 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file 1-D Distributed GEMM Schedules + + NOTE: This API is __experimental__ and will change heavily over time. Particularly the use of + CuTe layouts as integer functions in defining iteration-to-tile mappings is over-expressive and + leaves plenty of room for incorrect/unexpected behavior. 
+ Please proceed with caution when modifying these schedules or defining new ones. + + Device/iteration mappings are defined with CuTe layouts, + since they are functions from integers to integers as well. + + Each mapping is defined as a linear function of 2 variables (rank-2 layout): + First variable (mode) is device index, second variable (mode) is iteration. + A constant is also added to the final result as an offset value. This is a temporary workaround + so that identity ownership mappings in the final iteration can be guaranteed for the schedules + currently implemented. + How are these mappings defined? + Each schedule represents a unique parallel matrix multiplication algorithm, which describes how + matrices/tensors are distributed among TP GPUs. + + Depending on the algorithm, access patterns (GPU to tile or (GPU, iteration) to tile) mappings) + are not necessarily going to be the identity function. + + Pitfalls: + The current representation uses CuTe layouts as arbitrary linear functions that map + (GPU, iteration) to tile indices. + This approach is over-expressive, and therefore makes a lot of assumptions on the part of the + developer in how these mappings are defined. This can easily lead to incorrect implementations + if not handled carefully. + + + Assumption made in all schedules: TP == number of iterations (stages) +*/ + +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" + +#include "cutlass/experimental/distributed/schedules/dist_gemm_base_schedule.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::distributed::schedules { + +// GEMM + Reduce Scatter +// A and B are tiled along the K mode, which means each GPU gets an [M, K / TP]-shaped slice of A, +// and an [N, K / TP] slice of B. 
+// A is further tiled along the M mode, so that each stage/iteration computes a GEMM of shape +// [M / TP, N, K / TP], and the epilogue will perform the reduction by reading its C tensor directly +// from the left peer's previous D buffer. +// +// Below is an illustration of the tiling and iteration mappings for this pattern in the TP=4 case: +// +// Rows correspond to the M mode, columns correspond to the K mode for A and B and N mode for +// C and D. Because sharding is done along K, each column of tiles is owned by one GPU. +// Values in the grid correspond to the iteration/stage accessing the tile. +// * means the same tile is accessed in all iterations/stages. +// +// Tensor A Tensor B +// +// GPU0 GPU1 GPU2 GPU3 GPU0 GPU1 GPU2 GPU3 +// |-----|-----|-----|-----| |-----|-----|-----|-----| +// | | | | | | | | | | +// | 3 | 0 | 1 | 2 | | | | | | +// |_____|_____|_____|_____| | | | | | +// | | | | | | | | | | +// | 2 | 3 | 0 | 1 | | | | | | +// |_____|_____|_____|_____| | * | * | * | * | +// | | | | | | | | | | +// | 1 | 2 | 3 | 0 | | | | | | +// |_____|_____|_____|_____| | | | | | +// | | | | | | | | | | +// | 0 | 1 | 2 | 3 | | | | | | +// |_____|_____|_____|_____| |_____|_____|_____|_____| +// +// M x K N x K +// +// +// Tensor C Tensor D +// (Peer's D) +// +// +// |-----------------------| |-----------------------| +// | | | | +// GPU0 | 1,2,3 | GPU0 | * | +// |_______________________| |_______________________| +// | | | | +// GPU1 | 1,2,3 | GPU1 | * | +// |_______________________| |_______________________| +// | | | | +// GPU2 | 1,2,3 | GPU2 | * | +// |_______________________| |_______________________| +// | | | | +// GPU3 | 1,2,3 | GPU3 | * | +// |_______________________| |_______________________| +// +// M x N M x N +// +// +// Tensor A's access pattern can be expressed as follows as a function of GPU index and iteration: +// tile_idx = ((device_idx - 1) - iter + TP) % TP +// +// and can be expressed with the following CuTe layout: +// (TP, TP) : (1, -1) 
+// with ProcessorOffset = -1 +// +// +// Note: Since this schedule does not expose any communication, iteration 0 has no reduction step, +// therefore epilogue is sourceless in iteration 0, and in the rest of the iterations the epilogue +// source is a remote pointer to Tensor D owned by its left peer. +// +// Left peer is simply (device_idx - 1 + TP) % TP, which is expressed with the following CuTe layout: +// (TP, TP) : (1, 0) +// +template +struct ReduceScatter1D_TilingA_RotatingC: BaseSchedule< + TP_, + /* ProcessorTiler_ = */ cute::Shape<_1, _1, TP_, _1>, + /* IterationTiler_ = */ cute::Shape, + /* PeerDeviceMapping_ = */ cute::Layout, cute::Stride<_1, _0>>, // (left neighbor) = (device_idx + ProcessorOffset + TP) % TP, with ProcessorOffset = -1 + /* IterationMappingM_ = */ cute::Layout, cute::Stride<_1, _m1>>, // = (device_idx + ProcessorOffset - iter + TP) % TP, with ProcessorOffset = -1 + /* IterationMappingN_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::N == 1) = 0 + /* IterationMappingK_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::K == 1) = 0 + /* IterationMappingL_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::L == 1) = 0 + /* ProcessorOffset_ = */ _m1, + /* MemcpyA_ = */ false, + /* MemcpyB_ = */ false, + /* KernelWritesArrivalFlag_ = */ true, + /* NumBuffersA_ = */ 0, + /* NumBuffersB_ = */ 0, + /* NumBuffersC_ = */ 0, + /* NumBuffersD_ = */ TP_{} - 1> {}; + +// This schedule is similar to ReduceScatter1D_TilingA_RotatingC, but with the second tiling +// done along N instead of M. All other details remain unchanged. 
+template +struct ReduceScatter1D_TilingB_RotatingC: BaseSchedule< + TP_, + /* ProcessorTiler_ = */ cute::Shape<_1, _1, TP_, _1>, + /* IterationTiler_ = */ cute::Shape<_1, TP_, _1, _1>, + /* PeerDeviceMapping_ = */ cute::Layout, cute::Stride<_1, _0>>, // (left neighbor) = (device_idx + ProcessorOffset + TP) % TP, with ProcessorOffset = -1 + /* IterationMappingM_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::N == 1) = 0 + /* IterationMappingN_ = */ cute::Layout, cute::Stride<_1, _m1>>, // = (device_idx + ProcessorOffset - iter + TP) % TP, with ProcessorOffset = -1 + /* IterationMappingK_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::K == 1) = 0 + /* IterationMappingL_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::L == 1) = 0 + /* ProcessorOffset_ = */ _m1, + /* MemcpyA_ = */ false, + /* MemcpyB_ = */ false, + /* KernelWritesArrivalFlag_ = */ true, + /* NumBuffersA_ = */ 0, + /* NumBuffersB_ = */ 0, + /* NumBuffersC_ = */ 0, + /* NumBuffersD_ = */ TP_{} - 1> {}; + + +// AllGather + GEMM +// A and B are tiled along the N mode, which means each GPU allgathers A, +// and operates with an [N / TP, K] slice of B. +// For pipelining, A is further tiled along the M mode, so that each stage/iteration computes a +// GEMM of shape [M / TP, N / TP, K], and concurrently we copy a peer's A slice into a local buffer +// for the next stage/iteration. +// +// Below is an illustration of the tiling and iteration mappings for this pattern in the TP=4 case: +// +// Rows correspond to the M mode, columns correspond to the K mode for A and B and N mode for +// C and D. +// +// Since this is a pipelined schedule without exposed communication, the first iteration starts +// off immediately and operates on local slices of A and B. In the rest of the iterations, each +// GPU accesses a slice of A copied from a peer GPU while it was busy with the last stage. 
+// +// Values in the following grids correspond to the peer buffer accessed by each GPU during +// different iterations: +// +// Tensor A Tensor A +// iter 0 iter 1 +// +// |-----------------------| |-----------------------| +// | | | | +// GPU0 | 0 | | 1 | +// |_______________________| |_______________________| +// | | | | +// GPU1 | 1 | | 2 | +// |_______________________| |_______________________| +// | | | | +// GPU2 | 2 | | 3 | +// |_______________________| |_______________________| +// | | | | +// GPU3 | 3 | | 0 | +// |_______________________| |_______________________| +// +// M x K M x K +// +// Tensor A Tensor A +// iter 2 iter 3 +// +// |-----------------------| |-----------------------| +// | | | | +// GPU0 | 2 | | 3 | +// |_______________________| |_______________________| +// | | | | +// GPU1 | 3 | | 0 | +// |_______________________| |_______________________| +// | | | | +// GPU2 | 0 | | 1 | +// |_______________________| |_______________________| +// | | | | +// GPU3 | 1 | | 2 | +// |_______________________| |_______________________| +// +// M x K M x K +// +// Values in the following grids correspond to the tile accessed during each iteration. +// * means the same tile is accessed in all iterations/stages. 
+// +// Tensor B Tensor C/D +// +// +// |-----------------------| |-----|-----|-----|-----| +// | | | | | | | +// GPU0 | * | GPU0 | 0 | 1 | 2 | 3 | +// |_______________________| |_____|_____|_____|_____| +// | | | | | | | +// GPU1 | * | GPU1 | 3 | 0 | 1 | 2 | +// |_______________________| |_____|_____|_____|_____| +// | | | | | | | +// GPU2 | * | GPU2 | 2 | 3 | 0 | 1 | +// |_______________________| |_____|_____|_____|_____| +// | | | | | | | +// GPU3 | * | GPU3 | 1 | 2 | 3 | 0 | +// |_______________________| |_____|_____|_____|_____| +// +// N x K M x N +// +// +// Tensor C/D's access pattern can be expressed as follows as a function of GPU index and iteration: +// tile_idx = (device_idx + iter) % TP +// +// and can be expressed with the following CuTe layout: +// (TP, TP) : (1, 1) +// +// This schedule does not need a ProcessorOffset constant. +// +// Peer devices from which A slices are copied is also expressed with the same function and CuTe +// layout. +// +template +struct AllGather1D_TilingCD_RotatingA: BaseSchedule< + TP_, + /* ProcessorTiler_ = */ cute::Shape<_1, TP_, _1, _1>, + /* IterationTiler_ = */ cute::Shape, + /* PeerDeviceMapping_ = */ cute::Layout, cute::Stride<_1, _1>>, // = device_idx + iter + /* IterationMappingM_ = */ cute::Layout, cute::Stride<_1, _1>>, // = device_idx + iter + /* IterationMappingN_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::N == 1) = 0 + /* IterationMappingK_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::K == 1) = 0 + /* IterationMappingL_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::L == 1) = 0 + /* ProcessorOffset_ = */ _0, + /* MemcpyA_ = */ true, + /* MemcpyB_ = */ false, + /* KernelWritesArrivalFlag_ = */ false, + /* NumBuffersA_ = */ TP_{} - 1, + /* NumBuffersB_ = */ 0, + /* NumBuffersC_ = */ 0, + /* NumBuffersD_ = */ 0>{}; + +// This schedule is similar to AllGather1D_TilingCD_RotatingA, but with the order of tiling +// swapped from N then M to M then N. 
This means slices of B are rotated around GPUs instead of +// slices of A. All other details remain unchanged. +template +struct AllGather1D_TilingCD_RotatingB: BaseSchedule< + TP_, + /* ProcessorTiler_ = */ cute::Shape, + /* IterationTiler_ = */ cute::Shape<_1, TP_, _1, _1>, + /* PeerDeviceMapping_ = */ cute::Layout, cute::Stride<_1, _1>>, // = device_idx + iter + /* IterationMappingM_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::M == 1) = 0 + /* IterationMappingN_ = */ cute::Layout, cute::Stride<_1, _1>>, // = device_idx + iter + /* IterationMappingK_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::K == 1) = 0 + /* IterationMappingL_ = */ cute::Layout, cute::Stride<_0, _0>>, // (IterationTiler::L == 1) = 0 + /* ProcessorOffset_ = */ _0, + /* MemcpyA_ = */ false, + /* MemcpyB_ = */ true, + /* KernelWritesArrivalFlag_ = */ false, + /* NumBuffersA_ = */ 0, + /* NumBuffersB_ = */ TP_{} - 1, + /* NumBuffersC_ = */ 0, + /* NumBuffersD_ = */ 0>{}; + + +} // namespace cutlass::distributed::schedules + +/////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_base_schedule.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_base_schedule.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3a2d33281379f71b504f7303637e410c787bba83 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/experimental/distributed/schedules/dist_gemm_base_schedule.hpp @@ -0,0 +1,538 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file Base Schedule for Distributed GEMM + + Templates Distributed GEMM schedules so that they can be expressed as a set of CuTe primitives and + other static values. + + NOTE: This API is __experimental__ and will change heavily over time. 
Particularly the use of + CuTe layouts as integer functions in defining iteration-to-tile mappings is over-expressive and + leaves plenty of room for incorrect/unexpected behavior. + Please proceed with caution when modifying these schedules or defining new ones. +*/ + +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" + + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::distributed::schedules { + +/* + * Distributed GEMM schedules define exactly how operand tensors are tiled and sliced across + * processors (GPUs) and stages/iterations. + * + * BaseSchedule's role is to ease the implementation of arbitrary Distributed GEMM schedules + * and reduce code repetition, simply by reducing the implementation to CuTe primitives and a few + * other static values (buffer sizes, whether tensors are rotated using memcpies or not, and the + * like.) + */ +template < + class TP_, // CuTe constant defining the number of processors / GPUs / TP value + class ProcessorTiler_, // CuTe tiler defining how fully materialized tensors are sharded across devices + class IterationTiler_, // CuTe tiler defining how local tensors are tiled across stages/iterations + class PeerDeviceMapping_, // CuTe layout mapping device index and stage/iteration to the device's peer index for that stage/iteration + class IterationMappingM_, // CuTe layout mapping device index and stage/iteration to M tile index + class IterationMappingN_, // CuTe layout mapping device index and stage/iteration to N tile index + class IterationMappingK_, // CuTe layout mapping device index and stage/iteration to K tile index + class IterationMappingL_, // CuTe layout mapping device index and stage/iteration to L tile index + class ProcessorOffset_, // Constant offset for processor / GPU index in iteration mapping + bool MemcpyA_, // Whether tensor A is memcpied + bool MemcpyB_, // Whether tensor B is memcpied + bool 
KernelWritesArrivalFlag_, // Whether the kernel writes arrival flags (when tensors are directly accessed from peer and not memcpied) + int NumBuffersA_, // Number of buffers required for tensor A + int NumBuffersB_, // Number of buffers required for tensor B + int NumBuffersC_, // Number of buffers required for tensor C + int NumBuffersD_> // Number of buffers required for tensor D +struct BaseSchedule { + + using TP = TP_; + + static_assert( + cute::is_static::value && cute::is_integral::value && cute::rank(TP{}) == 1 && cute::depth(TP{}) == 0, + "Only integers allowed for TP at this time."); + + static_assert(cute::rank(ProcessorTiler_{}) == 4, "Expected rank-4 processor tiler."); + static_assert(cute::rank(IterationTiler_{}) == 4, "Expected rank-4 iteration tiler."); + + static_assert(cute::rank(PeerDeviceMapping_{}) == 2, + "PeerDeviceMapping must be rank-2 (device_idx, iter)"); + + static_assert(cute::rank(IterationMappingM_{}) == 2, + "IterationMappingM must be rank-2 (device_idx, iter)."); + static_assert(cute::rank(IterationMappingN_{}) == 2, + "IterationMappingN must be rank-2 (device_idx, iter)."); + static_assert(cute::rank(IterationMappingK_{}) == 2, + "IterationMappingK must be rank-2 (device_idx, iter)."); + static_assert(cute::rank(IterationMappingL_{}) == 2, + "IterationMappingL must be rank-2 (device_idx, iter)."); + + using ProcessorTiler = ProcessorTiler_; + using IterationTiler = IterationTiler_; + + using PeerDeviceMapping = PeerDeviceMapping_; + using IterationMappingM = IterationMappingM_; + using IterationMappingN = IterationMappingN_; + using IterationMappingK = IterationMappingK_; + using IterationMappingL = IterationMappingL_; + + using ProcessorOffset = ProcessorOffset_; + + static constexpr bool KernelWritesArrivalFlag = KernelWritesArrivalFlag_; + static constexpr bool MemcpyA = MemcpyA_; + static constexpr bool MemcpyB = MemcpyB_; + static constexpr bool HasMemcpy = MemcpyA || MemcpyB; + + static constexpr int NumBuffersA = 
NumBuffersA_; + static constexpr int NumBuffersB = NumBuffersB_; + static constexpr int NumBuffersC = NumBuffersC_; + static constexpr int NumBuffersD = NumBuffersD_; + + static_assert( + NumBuffersA > 0 ^ + NumBuffersB > 0 ^ + NumBuffersC > 0 ^ + NumBuffersD > 0, + "Only one of the ABCD tensors can be buffered!"); + + static constexpr bool BufferedOutput = NumBuffersC > 0 || NumBuffersD > 0; + static constexpr bool RemoteC = NumBuffersC == 0 && NumBuffersD > 0; + static constexpr bool RemoteD = NumBuffersD == 0 && NumBuffersC > 0; + + static_assert(not RemoteD, "Remote D is not supported yet."); + + // Host-side API: can_implement based on the GLOBAL problem shape + template + static bool + can_implement_global(ProblemShape const& global_problem_shape) { + auto [M, N, K, L] = append<4>(global_problem_shape, 1); + + auto [ptileM, ptileN, ptileK, ptileL] = ProcessorTiler{}; + auto [itileM, itileN, itileK, itileL] = IterationTiler{}; + + auto tileM = ptileM * itileM; + auto tileN = ptileN * itileN; + auto tileK = ptileK * itileK; + auto tileL = ptileL * itileL; + + return M % tileM == 0 && N % tileN == 0 && K % tileK == 0 && L % tileL == 0; + } + + template + CUTLASS_HOST_DEVICE + static auto + get_local_gemm_shape(ProblemShape const& global_problem_shape) { + auto problem_shape_MNKL = append<4>(global_problem_shape, 1); + + return shape_div( + shape_div( + problem_shape_MNKL, + ProcessorTiler{}), + IterationTiler{}); + } + + // Host-side API: determine peers + static auto + get_peers_for_device(int device_idx) { + auto left_peer_id = device_idx > 0 ? device_idx - 1 : TP{} - 1; + auto right_peer_id = device_idx < TP{} - 1 ? 
device_idx + 1 : 0; + + return cute::make_tuple(left_peer_id, right_peer_id); + } + + // Determines peer given device index and iteration + static int + get_remote_peer_id(int device_idx, int iteration) { + auto device_iter_to_peer_idx = PeerDeviceMapping{}; + auto peer_idx = ( + device_iter_to_peer_idx(device_idx + ProcessorOffset{}, iteration) + TP{} + ) % TP{}; + return peer_idx; + } + + // Construct tilers and index mappers for sharding across processors + template + CUTLASS_HOST_DEVICE + static auto + get_processor_tiler_a(Tensor tensor) { + if constexpr (NumBuffersA > 0) { + return shape_div(tensor.shape(), select<0,2,3>(IterationTiler{})); + } else { + return shape_div(tensor.shape(), select<0,2,3>(ProcessorTiler{})); + } + } + + template + CUTLASS_HOST_DEVICE + static auto + get_processor_tiler_b(Tensor tensor) { + if constexpr (NumBuffersB > 0) { + return shape_div(tensor.shape(), select<1,2,3>(IterationTiler{})); + } else { + return shape_div(tensor.shape(), select<1,2,3>(ProcessorTiler{})); + } + } + + template + CUTLASS_HOST_DEVICE + static auto + get_processor_tiler_c(Tensor tensor) { + if constexpr (BufferedOutput) { + return shape_div(tensor.shape(), select<0,1,3>(IterationTiler{})); + } else { + return shape_div(tensor.shape(), select<0,1,3>(ProcessorTiler{})); + } + } + + template + CUTLASS_HOST_DEVICE + static auto + get_processor_tiler_d(Tensor tensor) { + return get_processor_tiler_c(tensor); + } + + // Construct tilers and index mappers for tiling and iterating on device + template + CUTLASS_HOST_DEVICE + static auto + get_device_tiler_a(Tensor tensor) { + static_assert(NumBuffersA == 0, "Buffered tensors don't have device tilers!"); + return shape_div(tensor.shape(), select<0,2,3>(IterationTiler{})); + } + + template + CUTLASS_HOST_DEVICE + static auto + get_device_tiler_b(Tensor tensor) { + static_assert(NumBuffersB == 0, "Buffered tensors don't have device tilers!"); + return shape_div(tensor.shape(), select<1,2,3>(IterationTiler{})); + } + 
+ template + CUTLASS_HOST_DEVICE + static auto + get_device_tiler_c(Tensor tensor) { + static_assert(NumBuffersC == 0 && NumBuffersD == 0, "Buffered tensors don't have device tilers!"); + return shape_div(tensor.shape(), select<0,1,3>(IterationTiler{})); + } + + template + CUTLASS_HOST_DEVICE + static auto + get_device_tiler_d(Tensor tensor) { + static_assert(NumBuffersC == 0 && NumBuffersD == 0, "Buffered tensors don't have device tilers!"); + return shape_div(tensor.shape(), select<0,1,3>(IterationTiler{})); + } + + // Map device index and iteration to tile coordinate + // Must be implemented by children for now. + CUTLASS_HOST_DEVICE + static auto + get_device_tile_idx_a(int device_idx, int iteration) { + auto mapping_m = IterationMappingM{}; + auto mapping_k = IterationMappingK{}; + auto mapping_l = IterationMappingL{}; + auto crd_m = (mapping_m(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + auto crd_k = (mapping_k(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + auto crd_l = (mapping_l(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + return make_coord(crd_m, crd_k, crd_l); + } + + CUTLASS_HOST_DEVICE + static auto + get_device_tile_idx_b(int device_idx, int iteration) { + auto mapping_n = IterationMappingN{}; + auto mapping_k = IterationMappingK{}; + auto mapping_l = IterationMappingL{}; + auto crd_n = (mapping_n(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + auto crd_k = (mapping_k(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + auto crd_l = (mapping_l(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + return make_coord(crd_n, crd_k, crd_l); + } + + CUTLASS_HOST_DEVICE + static auto + get_device_tile_idx_c(int device_idx, int iteration) { + auto mapping_m = IterationMappingM{}; + auto mapping_n = IterationMappingN{}; + auto mapping_l = IterationMappingL{}; + auto crd_m = (mapping_m(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + auto crd_n = (mapping_n(device_idx + 
ProcessorOffset{}, iteration) + TP{}) % TP{}; + auto crd_l = (mapping_l(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + return make_coord(crd_m, crd_n, crd_l); + } + + CUTLASS_HOST_DEVICE + static auto + get_device_tile_idx_d(int device_idx, int iteration) { + auto mapping_m = IterationMappingM{}; + auto mapping_n = IterationMappingN{}; + auto mapping_l = IterationMappingL{}; + auto crd_m = (mapping_m(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + auto crd_n = (mapping_n(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + auto crd_l = (mapping_l(device_idx + ProcessorOffset{}, iteration) + TP{}) % TP{}; + return make_coord(crd_m, crd_n, crd_l); + } + + // Device Partitioners: partition non-buffered processor-resident operands. + // Processor-resident operands fall into two categories: buffered, and not buffered. + // Those buffered aren't expected to be further partitioned, and those + template + static auto + get_tensor_A(Tensor original_tensor, void * tensor_buffer_ptr, int device_idx, int iteration) { + static_assert(rank(original_tensor) == 3); + + using Element = typename Tensor::value_type; + // Recreate tensor without constness. This is to ensure return types match. 
+ Element* ptr = const_cast(original_tensor.data()); + auto shape = original_tensor.shape(); + auto layout = original_tensor.layout(); + auto tensor = make_tensor(ptr, layout); + + if constexpr (NumBuffersA == 0) { + auto tiler = get_device_tiler_a(tensor); + auto idx = get_device_tile_idx_a(device_idx, iteration); + return inner_partition(tensor, tiler, idx); + } else { + Element* ptr_buffer = reinterpret_cast(tensor_buffer_ptr); + if (iteration == 0) { + return tensor; + } + ptr_buffer += size(shape) * (iteration - 1); + + return make_tensor(ptr_buffer, layout); + } + } + + template + static auto + get_tensor_B(Tensor original_tensor, void * tensor_buffer_ptr, int device_idx, int iteration) { + static_assert(rank(original_tensor) == 3); + + using Element = typename Tensor::value_type; + // Recreate tensor without constness. This is to ensure return types match. + Element * ptr = const_cast(original_tensor.data()); + auto shape = original_tensor.shape(); + auto layout = original_tensor.layout(); + auto tensor = make_tensor(ptr, layout); + + if constexpr (NumBuffersB == 0) { + auto tiler = get_device_tiler_b(tensor); + auto idx = get_device_tile_idx_b(device_idx, iteration); + return inner_partition(tensor, tiler, idx); + } else { + Element * ptr_buffer = reinterpret_cast(tensor_buffer_ptr); + if (iteration == 0) { + return tensor; + } + ptr_buffer += size(shape) * (iteration - 1); + + return make_tensor(ptr_buffer, layout); + } + } + + template + static auto + get_tensor_C(Tensor original_tensor, void * tensor_buffer_ptr, int device_idx, int iteration) { + static_assert(rank(original_tensor) == 3); + + using Element = typename Tensor::value_type; + // Recreate tensor without constness. This is to ensure return types match. 
+ Element * ptr = const_cast(original_tensor.data()); + auto shape = original_tensor.shape(); + auto layout = original_tensor.layout(); + auto tensor = make_tensor(ptr, layout); + + if constexpr (not BufferedOutput) { + auto tiler = get_device_tiler_c(tensor); + auto idx = get_device_tile_idx_c(device_idx, iteration); + return inner_partition(tensor, tiler, idx); + } else { + // implement Remote D + static_assert(RemoteC, ""); + + Element * ptr_buffer = reinterpret_cast(tensor_buffer_ptr); + if (iteration == 0) { + return tensor; + } + ptr_buffer += size(shape) * (iteration - 1); + + return make_tensor(ptr_buffer, layout); + } + } + + template + static auto + get_tensor_D(Tensor original_tensor, void * tensor_buffer_ptr, int device_idx, int iteration) { + static_assert(rank(original_tensor) == 3); + + using Element = typename Tensor::value_type; + // Recreate tensor without constness. This is to ensure return types match. + Element * ptr = const_cast(original_tensor.data()); + auto shape = original_tensor.shape(); + auto layout = original_tensor.layout(); + auto tensor = make_tensor(ptr, layout); + + if constexpr (not BufferedOutput) { + auto tiler = get_device_tiler_d(tensor); + auto idx = get_device_tile_idx_d(device_idx, iteration); + return inner_partition(tensor, tiler, idx); + } else { + // implement Remote D + static_assert(RemoteC, ""); + + Element * ptr_buffer = reinterpret_cast(tensor_buffer_ptr); + // last iteration is the local tensor, the rest are buffers + if (iteration == TP{} - 1) { + return tensor; + } + ptr_buffer += size(shape) * iteration; // note: iteration, not iteration - 1 + + return make_tensor(ptr_buffer, layout); + } + } + + template + CUTLASS_HOST_DEVICE + static auto + get_local_a_shape(ProblemShape problem_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + if constexpr (NumBuffersA == 0) { + return shape_div( + select<0,2,3>(problem_shape_MNKL), + select<0,2,3>(ProcessorTiler{})); + } else { + return shape_div( + 
shape_div( + select<0,2,3>(problem_shape_MNKL), + select<0,2,3>(ProcessorTiler{})), + select<0,2,3>(IterationTiler{})); + } + } + + template + CUTLASS_HOST_DEVICE + static auto + get_local_b_shape(ProblemShape problem_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + if constexpr (NumBuffersB == 0) { + return shape_div( + select<1,2,3>(problem_shape_MNKL), + select<1,2,3>(ProcessorTiler{})); + } else { + return shape_div( + shape_div( + select<1,2,3>(problem_shape_MNKL), + select<1,2,3>(ProcessorTiler{})), + select<1,2,3>(IterationTiler{})); + } + } + + template + CUTLASS_HOST_DEVICE + static auto + get_local_c_shape(ProblemShape problem_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + if constexpr (not BufferedOutput) { + return shape_div( + select<0,1,3>(problem_shape_MNKL), + select<0,1,3>(ProcessorTiler{})); + } else { + return shape_div( + shape_div( + select<0,1,3>(problem_shape_MNKL), + select<0,1,3>(ProcessorTiler{})), + select<0,1,3>(IterationTiler{})); + } + } + + template + CUTLASS_HOST_DEVICE + static auto + get_local_d_shape(ProblemShape problem_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + if constexpr (not BufferedOutput) { + return shape_div( + select<0,1,3>(problem_shape_MNKL), + select<0,1,3>(ProcessorTiler{})); + } else { + return shape_div( + shape_div( + select<0,1,3>(problem_shape_MNKL), + select<0,1,3>(ProcessorTiler{})), + select<0,1,3>(IterationTiler{})); + } + } + + // Host-side APIs: get_device_slice_{A,B,C,D} + // Slice off a view of the GLOBAL tensor that corresponds to the shard that + // is going to be owned by a specific device. This helps with the initial + // distribution of the GLOBAL operands among devices. 
+ template + static auto + get_device_slice_A(Tensor tensor, int device_idx) { + auto tiler = get_processor_tiler_a(tensor); + return inner_partition(tensor, tiler, device_idx); + } + + template + static auto + get_device_slice_B(Tensor tensor, int device_idx) { + auto tiler = get_processor_tiler_b(tensor); + return inner_partition(tensor, tiler, device_idx); + } + + template + static auto + get_device_slice_C(Tensor tensor, int device_idx) { + auto tiler = get_processor_tiler_c(tensor); + return inner_partition(tensor, tiler, device_idx); + } + + template + static auto + get_device_slice_D(Tensor tensor, int device_idx) { + auto tiler = get_processor_tiler_d(tensor); + return inner_partition(tensor, tiler, device_idx); + } +}; + + + +} // namespace cutlass::gemm::distributed + +/////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/fast_math.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/fast_math.h new file mode 100644 index 0000000000000000000000000000000000000000..eb14856f081f26b591cd4524b55f1cfadca245a7 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/fast_math.h @@ -0,0 +1,1085 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#include +#include +#endif +#if !defined(__QNX__) +#include CUDA_STD_HEADER(utility) +#endif +#include "cutlass/array.h" +#include "cutlass/uint128.h" +#include "cutlass/coord.h" +#include "cutlass/half.h" + +/** + * \file + * \brief Math utilities + */ + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// +#if !defined(__QNX__) +using ::cuda::std::swap; +#else +template +CUTLASS_HOST_DEVICE void swap(T &lhs, T &rhs) { + T tmp = lhs; + lhs = rhs; + rhs = tmp; +} +#endif + +/****************************************************************************** + * Static math utilities + ******************************************************************************/ + +/// Mixed precision dot product +template +CUTLASS_HOST_DEVICE LongIndex dot( + Coord const &coord, + Coord const &stride, + LongIndex acc = LongIndex()) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < N; ++n) { + acc += LongIndex(coord[n]) * stride[n]; + } + return acc; +} + +/** + * Statically determine if N is a power-of-two + */ +template +struct is_pow2 { + static bool const value = ((N & (N - 1)) == 0); +}; + +/** + * Statically determine log2(N), rounded down + */ +template +struct log2_down { + /// Static logarithm value + enum { value = log2_down> 1), Count + 1>::value }; +}; + +// Base case +template +struct log2_down { + enum { value = Count }; +}; + +/** + * Statically determine log2(N), rounded up + */ +template +struct log2_up { + /// Static logarithm value + enum { value = log2_up> 1), Count + 1>::value }; +}; + +// Base case +template +struct log2_up { + enum { value = ((1 << Count) < N) ? 
Count + 1 : Count }; +}; + +/** + * Statically estimate sqrt(N) to the nearest power-of-two + */ +template +struct sqrt_est { + enum { value = 1 << (log2_up::value / 2) }; +}; + +/** + * For performing a constant-division with a compile-time assertion that the + * Divisor evenly-divides the Dividend. + */ +template +struct divide_assert { + enum { value = Dividend / Divisor }; + + static_assert((Dividend % Divisor == 0), "Not an even multiple"); +}; + +/****************************************************************************** + * Rounding + ******************************************************************************/ + +/** + * Round dividend up to the nearest multiple of divisor + */ +template +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +dividend_t round_nearest(dividend_t dividend, divisor_t divisor) { + return ((dividend + divisor - 1) / divisor) * divisor; +} + +template +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +value_t abs_for_integer(value_t a) { + return ((a > value_t{0}) ? a : -a); +} +/** + * Greatest common divisor + */ +template +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +value_t gcd(value_t a, value_t b) { + for (;;) { + if (a == value_t{0}) return cutlass::abs_for_integer(b); + b %= a; + if (b == value_t{0}) return cutlass::abs_for_integer(a); + a %= b; + } +} + +/** + * Least common multiple + */ +template +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +value_t lcm(value_t a, value_t b) { + value_t temp = cutlass::gcd(a, b); + return (temp != value_t{0}) ? value_t(cutlass::abs_for_integer(a) / temp * cutlass::abs_for_integer(b)) : value_t{}; +} + +/** + * Greatest common divisor + */ +template +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +value_t gcd_cxx11(value_t a, value_t b) { + return (a == value_t{0} || b == value_t{0}) ? 
cutlass::abs_for_integer(a | b) : cutlass::gcd_cxx11(b, a % b); +} + +/** + * Least common multiple + */ +template +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +value_t lcm_cxx11(value_t a, value_t b) { + return cutlass::gcd_cxx11(a, b) ? (cutlass::abs_for_integer(a) / cutlass::gcd_cxx11(a, b) * + cutlass::abs_for_integer(b)) + : value_t{}; +} + +/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +int round_up(int a, int b) { + return ((a + b - 1) / b) * b; +} + +/// Returns the ceiling of (a / b) +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * log2 computation, what's the + * difference between the below codes and + * log2_up/down codes? + */ +template +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +value_t clz(value_t x) { + for (int i = 31; i >= 0; --i) { + if ((1 << i) & x) + return value_t(31 - i); + } + return value_t(32); +} + +template +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +value_t find_log2(value_t x) { + int a = int(31 - clz(x)); + a += (x & (x - 1)) != 0; // Round up, add 1 if not a power of 2. 
+ return a; +} + + +/** + * Find divisor, using find_log2 + */ +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +void find_divisor(unsigned int& mul, unsigned int& shr, unsigned int denom) { + if (denom == 1) { + mul = 0; + shr = 0; + } else { + unsigned int p = 31 + find_log2(denom); + unsigned m = unsigned(((1ull << p) + unsigned(denom) - 1) / unsigned(denom)); + + mul = m; + shr = p - 32; + } +} + +/** + * Find quotient and remainder using device-side intrinsics + */ +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +void fast_divmod(int& quo, int& rem, int src, int div, unsigned int mul, unsigned int shr) { + + #if defined(__CUDA_ARCH__) + // Use IMUL.HI if div != 1, else simply copy the source. + quo = (div != 1) ? __umulhi(src, mul) >> shr : src; + #else + quo = int((div != 1) ? int(((int64_t)src * mul) >> 32) >> shr : src); + #endif + + // The remainder. + rem = src - (quo * div); +} + +// For long int input +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 +void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul, unsigned int shr) { + + #if defined(__CUDA_ARCH__) + // Use IMUL.HI if div != 1, else simply copy the source. + quo = (div != 1) ? __umulhi(src, mul) >> shr : src; + #else + quo = int((div != 1) ? ((src * mul) >> 32) >> shr : src); + #endif + // The remainder. + rem = src - (quo * div); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Object to encapsulate the fast division+modulus operation. +/// +/// This object precomputes two values used to accelerate the computation and is best used +/// when the divisor is a grid-invariant. In this case, it may be computed in host code and +/// marshalled along other kernel arguments using the 'Params' pattern. 
+/// +/// Example: +/// +/// +/// int quotient, remainder, dividend, divisor; +/// +/// FastDivmod divmod(divisor); +/// +/// divmod(quotient, remainder, dividend); +/// +/// // quotient = (dividend / divisor) +/// // remainder = (dividend % divisor) +/// +struct FastDivmod { + using value_div_type = int; + using value_mod_type = int64_t; + int32_t divisor = 1; + uint32_t multiplier = 0u; + uint32_t shift_right = 0u; + + // Find quotient and remainder using device-side intrinsics + CUTLASS_HOST_DEVICE + void fast_divmod(int& quotient, int& remainder, int dividend) const { + +#if defined(__CUDA_ARCH__) + // Use IMUL.HI if divisor != 1, else simply copy the source. + quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend; +#else + quotient = int((divisor != 1) ? int(((int64_t)dividend * multiplier) >> 32) >> shift_right : dividend); +#endif + + // The remainder. + remainder = dividend - (quotient * divisor); + } + + /// For long int input + CUTLASS_HOST_DEVICE + void fast_divmod(int& quotient, int64_t& remainder, int64_t dividend) const { + +#if defined(__CUDA_ARCH__) + // Use IMUL.HI if divisor != 1, else simply copy the source. + quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend; +#else + quotient = int((divisor != 1) ? ((dividend * multiplier) >> 32) >> shift_right : dividend); +#endif + // The remainder. + remainder = dividend - (quotient * divisor); + } + + + /// Construct the FastDivmod object, in host code ideally. + /// + /// This precomputes some values based on the divisor and is computationally expensive. 
+ + constexpr FastDivmod() = default; + + CUTLASS_HOST_DEVICE + FastDivmod(int divisor_): divisor(divisor_) { + assert(divisor_ >= 0); + if (divisor != 1) { + unsigned int p = 31 + find_log2(divisor); + unsigned m = unsigned(((1ull << p) + unsigned(divisor) - 1) / unsigned(divisor)); + + multiplier = m; + shift_right = p - 32; + } + } + + /// Computes integer division and modulus using precomputed values. This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + void operator()(int "ient, int &remainder, int dividend) const { + fast_divmod(quotient, remainder, dividend); + } + + /// Computes integer division using precomputed values. This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + int div(int dividend) const { + int quotient, remainder; + fast_divmod(quotient, remainder, dividend); + return quotient; + } + + /// Alias for `div` to match the interface of FastDivmodU64 + CUTLASS_HOST_DEVICE + int divide(int dividend) const { + return div(dividend); + } + + /// Computes integer division remainder using precomputed values. + CUTLASS_HOST_DEVICE + int rem(int dividend) const { + int quotient, remainder; + fast_divmod(quotient, remainder, dividend); + return remainder; + } + + /// Alias for `rem` + CUTLASS_HOST_DEVICE + int remainder(int dividend) const { + return rem(dividend); + } + + /// Computes integer division and modulus using precomputed values. This is computationally + /// inexpensive. + /// + /// Simply returns the quotient + CUTLASS_HOST_DEVICE + int divmod(int &remainder, int dividend) const { + int quotient; + fast_divmod(quotient, remainder, dividend); + return quotient; + } + + /// Computes integer division and modulus using precomputed values. This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + void operator()(int "ient, int64_t &remainder, int64_t dividend) const { + fast_divmod(quotient, remainder, dividend); + } + + /// Computes integer division and modulus using precomputed values. 
This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + int divmod(int64_t &remainder, int64_t dividend) const { + int quotient; + fast_divmod(quotient, remainder, dividend); + return quotient; + } + + /// Returns the divisor when cast to integer + CUTLASS_HOST_DEVICE + operator int() const { return divisor; } + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Object to encapsulate the fast division+modulus operation for 64b integer division. +/// +/// This object precomputes two values used to accelerate the computation and is best used +/// when the divisor is a grid-invariant. In this case, it may be computed in host code and +/// marshalled along other kernel arguments using the 'Params' pattern. +/// +/// Example: +/// +/// +/// uint64_t quotient, remainder, dividend, divisor; +/// +/// FastDivmodU64 divmod(divisor); +/// +/// divmod(quotient, remainder, dividend); +/// +/// // quotient = (dividend / divisor) +/// // remainder = (dividend % divisor) +/// +struct FastDivmodU64 { + + uint64_t divisor; + uint64_t multiplier; + unsigned int shift_right; + unsigned int round_up; + + // + // Static methods + // + + /// Computes b, where 2^b is the greatest power of two that is less than or equal to x + CUTLASS_HOST_DEVICE + static uint32_t integer_log2(uint64_t x) { + uint32_t n = 0; + while (x >>= 1) { + ++n; + } + return n; + } + + /// Default ctor + CUTLASS_HOST_DEVICE + FastDivmodU64(): divisor(0), multiplier(0), shift_right(0), round_up(0) { } + + /// Construct the FastDivmod object, in host code ideally. + /// + /// This precomputes some values based on the divisor and is computationally expensive. 
+ CUTLASS_HOST_DEVICE + FastDivmodU64(uint64_t divisor_): divisor(divisor_), multiplier(1), shift_right(0), round_up(0) { + + if (divisor) { + shift_right = integer_log2(divisor); + + if ((divisor & (divisor - 1)) == 0) { + multiplier = 0; + } + else { + uint64_t power_of_two = (uint64_t(1) << shift_right); + uint64_t multiplier_lo = uint128_t(0, power_of_two) / divisor; + multiplier = uint128_t(power_of_two, power_of_two) / divisor; + round_up = (multiplier_lo == multiplier ? 1 : 0); + } + } + } + + /// Returns the quotient of floor(dividend / divisor) + CUTLASS_HOST_DEVICE + uint64_t divide(uint64_t dividend) const { + uint64_t quotient = 0; + + #ifdef __CUDA_ARCH__ + uint64_t x = dividend; + if (multiplier) { + x = __umul64hi(dividend + round_up, multiplier); + } + quotient = (x >> shift_right); + #else + quotient = dividend / divisor; + #endif + + return quotient; + } + + /// Computes the remainder given a computed quotient and dividend + CUTLASS_HOST_DEVICE + uint64_t modulus(uint64_t quotient, uint64_t dividend) const { + return dividend - quotient * divisor; + } + + /// Returns the quotient of floor(dividend / divisor) and computes the remainder + CUTLASS_HOST_DEVICE + uint64_t divmod(uint64_t &remainder, uint64_t dividend) const { + uint64_t quotient = divide(dividend); + remainder = modulus(quotient, dividend); + return quotient; + } + + /// Computes integer division and modulus using precomputed values. This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + void operator()(uint64_t "ient, uint64_t &remainder, uint64_t dividend) const { + quotient = divmod(remainder, dividend); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Object to encapsulate the fast division+modulus operation for 64b integer division +/// in which the divisor is a power of two. 
+struct FastDivmodU64Pow2 { + + uint64_t divisor; + unsigned int shift_right; + + /// Default ctor + CUTLASS_HOST_DEVICE + FastDivmodU64Pow2(): divisor(0), shift_right(0) { } + + /// Construct the FastDivmod object, in host code ideally. + /// + /// This precomputes some values based on the divisor and is computationally expensive. + CUTLASS_HOST_DEVICE + FastDivmodU64Pow2(uint64_t divisor_): divisor(divisor_), shift_right(FastDivmodU64::integer_log2(divisor_)) { } + + /// Returns the quotient of floor(dividend / divisor) + CUTLASS_HOST_DEVICE + uint64_t divide(uint64_t dividend) const { + return dividend >> shift_right; + } + + /// Computes the remainder given a computed quotient and dividend + CUTLASS_HOST_DEVICE + uint64_t modulus(uint64_t dividend) const { + // See https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#division-modulo-operations + return dividend & (divisor - 1); + } + + /// Returns the quotient of floor(dividend / divisor) and computes the remainder + CUTLASS_HOST_DEVICE + uint64_t divmod(uint64_t &remainder, uint64_t dividend) const { + uint64_t quotient = divide(dividend); + remainder = modulus(dividend); + return quotient; + } + + /// Computes integer division and modulus using precomputed values. This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + void operator()(uint64_t "ient, uint64_t &remainder, uint64_t dividend) const { + quotient = divmod(remainder, dividend); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes the coordinate decomposition from a linear index (64-bit linear index => coord) +/// +/// This decomposition is accelerated by the FastDivmodU64 object. It is assumed that +/// a coordinate of indices can be decomposed by div/mod operations. +/// Note, is assumed that element divmod[0] divides by extent[1]. +/// +/// For example, assume 4-D coordinate (n, p, q, c) is mapped to a linear index `npqc`. 
This +/// can be decomposed via three divide and modulus operations: +/// +/// c = npqc % C; | divmod[2] = FastDivmodU64(C) +/// npq = npqc / C; | coord[3] = c +/// +/// q = npq % Q; | divmod[1] = FastDivmodU64(Q) +/// np = npq / Q; | coord[2] = q +/// +/// p = np % P; | divmod[0] = FastDivmodU64(P) +/// n = np / P; | coord[1] = p +/// +/// | coord[0] = n +/// +template +CUTLASS_HOST_DEVICE Coord CoordinateDecomposition( + uint64_t linear_idx, ///< Linear index to decompose + FastDivmodU64 const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects + + static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); + + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank; i > 1; --i) { + uint64_t remainder; + linear_idx = divmod[i - 2].divmod(remainder, linear_idx); + coord[i - 1] = int(remainder); + } + + coord[0] = int(linear_idx); + + return coord; +} + +/// Computes the coordinate decomposition from a linear index (32-bit linear index => coord) +template +CUTLASS_HOST_DEVICE Coord CoordinateDecomposition( + int linear_idx, ///< Linear index to decompose + FastDivmod const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects + + static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); + + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank; i > 1; --i) { + int remainder; + linear_idx = divmod[i - 2].divmod(remainder, linear_idx); + coord[i - 1] = int(remainder); + } + + coord[0] = int(linear_idx); + + return coord; +} + +template +CUTLASS_HOST_DEVICE Coord CoordinateDecompositionLittleEndian( + uint64_t linear_idx, ///< Linear index to decompose + FastDivmodU64 const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects + + static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); + + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank - 1; ++i) { + uint64_t remainder; + linear_idx = divmod[i].divmod(remainder, linear_idx); + coord[i] = 
int(remainder); + } + + coord[Rank - 1] = int(linear_idx); + + return coord; +} + +/// Computes the coordinate decomposition from a linear index (32-bit linear index => coord) +template +CUTLASS_HOST_DEVICE Coord CoordinateDecompositionLittleEndian( + int linear_idx, ///< Linear index to decompose + FastDivmod const *divmod) { ///< Pointer to array of Rank-1 FastDivmodU64 objects + + static_assert(Rank > 0, "CoordinateDecomposition requires Rank=1 or greater."); + + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank - 1; ++i) { + int remainder; + linear_idx = divmod[i].divmod(remainder, linear_idx); + coord[i] = int(remainder); + } + + coord[Rank - 1] = int(linear_idx); + + return coord; +} + +/// Safely computes the offset of a linear index in bytes for all types +template +CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index) { + + static_assert( + (sizeof_bits::value >= 8 && !(sizeof_bits::value % 8)) || + (sizeof_bits::value < 8 && !(8 % sizeof_bits::value)), + "Size of numeric type in bits must either be divisible by 8 bits, or 8 bits must be divisible by the size."); + + if (sizeof_bits::value >= 8) { + return index * (sizeof_bits::value / 8); + } + else { + int const kElementsPerByte = ((8 / sizeof_bits::value) + ((sizeof_bits::value >= 8) ? 1 : 0)); + return index / kElementsPerByte; + } +} + +CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index, int64_t element_sizeof_bits) { + if (element_sizeof_bits >= 8) { + return index * (element_sizeof_bits / 8); + } + else { + int64_t const kElementsPerByte = ((8 / element_sizeof_bits) + ((element_sizeof_bits >= 8) ? 1 : 0)); + return index / kElementsPerByte; + } +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Min/Max +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Min { + static int const kValue = (A < B) ? 
A : B; +}; + +template +struct Max { + static int const kValue = (A > B) ? A : B; +}; + +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 int const_min(int a, int b) { + return (b < a ? b : a); +} + +CUTLASS_HOST_DEVICE +CUTLASS_CONSTEXPR_IF_CXX17 int const_max(int a, int b) { + return (b > a ? b : a); +} + +template +CUTLASS_HOST_DEVICE +T fast_min(T a, T b) { + return (b < a ? b : a); +} + +template <> +CUTLASS_HOST_DEVICE +float fast_min(float a, float b) { + return fminf(a, b); +} + +template +CUTLASS_HOST_DEVICE +T fast_max(T a, T b) { + return (a < b ? b : a); +} + +template <> +CUTLASS_HOST_DEVICE +float fast_max(float a, float b) { + return fmaxf(a, b); +} + +CUTLASS_HOST_DEVICE +float fast_cos(float theta) { + #if defined(__CUDA_ARCH__) + return ::cosf(theta); + #else + return std::cos(theta); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_cos(double theta) { + #if defined(__CUDA_ARCH__) + return ::cos(theta); + #else + return std::cos(theta); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_sin(float theta) { + #if defined(__CUDA_ARCH__) + return ::sinf(theta); + #else + return std::sin(theta); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_sin(double theta) { + #if defined(__CUDA_ARCH__) + return ::sin(theta); + #else + return std::sin(theta); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_acos(float theta) { + #if defined(__CUDA_ARCH__) + return ::acosf(theta); + #else + return std::acos(theta); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_acos(double theta) { + #if defined(__CUDA_ARCH__) + return ::acos(theta); + #else + return std::acos(theta); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_asin(float theta) { + #if defined(__CUDA_ARCH__) + return ::asinf(theta); + #else + return std::asin(theta); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_asin(double theta) { + #if defined(__CUDA_ARCH__) + return ::asin(theta); + #else + return std::asin(theta); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_sqrt(float theta) { + #if defined(__CUDA_ARCH__) + 
return ::sqrtf(theta); + #else + return std::sqrt(theta); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_sqrt(double theta) { + #if defined(__CUDA_ARCH__) + return ::sqrt(theta); + #else + return std::sqrt(theta); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_exp(float x) { + #if defined(__CUDA_ARCH__) + return ::expf(x); + #else + return std::exp(x); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_exp(double x) { + #if defined(__CUDA_ARCH__) + return ::exp(x); + #else + return std::exp(x); + #endif +} + +CUTLASS_HOST_DEVICE +half_t fast_exp(half_t x) { + #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750) + return (half_t)(::hexp(x.to_half())); + #else + return (half_t)(fast_exp(float(x))); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_log(float x) { + #if defined(__CUDA_ARCH__) + return ::logf(x); + #else + return std::log(x); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_log(double x) { + #if defined(__CUDA_ARCH__) + return ::log(x); + #else + return std::log(x); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_tanh(float x) { + #if defined(__CUDA_ARCH__) + #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) + float y; + asm volatile ( "tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x)); + return y; + #else + return ::tanhf(x); + #endif + #else + return std::tanh(x); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_tanh(double x) { + #if defined(__CUDA_ARCH__) + return ::tanh(x); + #else + return std::tanh(x); + #endif +} + +CUTLASS_HOST_DEVICE +half_t fast_tanh(half_t x) { + #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) + + asm volatile ( "tanh.approx.f16 %0, %1;" : "=h"(x.raw()) : "h"(x.raw())); + return x; + + #else + return half_t(fast_tanh(float(x))); + #endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct fast_exp_op { + CUTLASS_HOST_DEVICE + T operator()(T const &rhs) const { + return fast_exp(rhs); + } 
+}; + +#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750) +template +struct fast_exp_op> { + CUTLASS_DEVICE + Array operator()(Array const &rhs) const { + + Array result; + + // use x2 specialization + __half2 const *in = reinterpret_cast<__half2 const *>(&rhs); + __half2 *out = reinterpret_cast<__half2 *>(&result); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + out[i] = ::h2exp(in[i]); + } + + // residual + if (N % 2) { + half_t last = rhs[N - 1]; + result[N - 1] = half_t(::hexp(last.to_half())); + } + + return result; + } +}; +#endif // #if defined(__CUDA_ARCH__) + +template +struct fast_exp_op> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs) const { + + fast_exp_op fast_op; + Array y; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + y[i] = fast_op(rhs[i]); + } + + return y; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct fast_tanh_op { + CUTLASS_HOST_DEVICE + T operator()(T const &rhs) const { + return fast_tanh(rhs); + } +}; + +#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) +template +struct fast_tanh_op> { + CUTLASS_DEVICE + Array operator()(Array const &rhs) const { + + Array result; + + // use x2 specialization + uint32_t const *in = reinterpret_cast(&rhs); + uint32_t *out = reinterpret_cast(&result); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm volatile ("tanh.approx.f16x2 %0, %1;" : "=r"(out[i]) : "r"(in[i])); + } + + // residual + if (N % 2) { + uint16_t const *in = reinterpret_cast(&rhs); + uint16_t *out = reinterpret_cast(&result); + asm volatile ("tanh.approx.f16 %0, %1;" : "=h"(out[N - 1]) : "h"(in[N - 1])); + } + + return result; + } +}; +#endif // #if defined(__CUDA_ARCH__) + +template +struct fast_tanh_op> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs) const { + + fast_tanh_op fast_op; + Array y; + + CUTLASS_PRAGMA_UNROLL 
+ for (int i = 0; i < N; ++i) { + y[i] = fast_op(rhs[i]); + } + + return y; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Absolute value function +template +CUTLASS_HOST_DEVICE +T absolute_value(T x) { + if (x < T()) { + return -x; + } + return x; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/float8.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/float8.h new file mode 100644 index 0000000000000000000000000000000000000000..eab0b35f901198316b2f2416fd24bcd6c7d2af70 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/float8.h @@ -0,0 +1,1685 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines a class for using IEEE half-precision floating-point types in host or + device code. 
+*/ + +#pragma once + + +#include "cutlass/arch/config.h" + + +// FP8 types are available starting CUDA 11.8+ +#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) +#define CUDA_FP8_ENABLED 1 +#endif + +#if defined(__CUDA_ARCH__) +# if (__CUDA_ARCH__ >= 900) +# if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) +# define CUDA_PTX_FP8_CVT_ENABLED 1 +# endif // (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) +# elif (__CUDA_ARCH__ == 890) +# if (__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1)) +# define CUDA_PTX_FP8_CVT_ENABLED 1 +# endif // (__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1)) +# endif // (__CUDA_ARCH__ >= 900) +#endif // defined(__CUDA_ARCH__) + + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# define CUDA_PTX_UE8M0_CVT_ENABLED 1 +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) +# define CUDA_PTX_UE8M0_CVT_ENABLED 1 +#endif + +#ifdef __GNUC__ +// Ignore checks on reinterpret-casts that are being used for bitcasts. +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDACC_RTC__) + +#include "cutlass/floating_point_nvrtc.h" + +#else +// +// Standard Library headers belong here to avoid conflicts with NVRTC. 
+// +#include +#include +#include +#include +#endif + +#ifdef CUDA_FP8_ENABLED +#include +#endif +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/exmy_base.h" + +#include "cute/util/type_traits.hpp" + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// FP8 Has 2 encodings possible : E4M3 and E5M2 +// +// E4M3 : 7 | 6 5 4 3 | 2 1 0 +// E5M2 : 7 | 6 5 4 3 2 | 1 0 +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class FloatEncoding { + E4M3, + E5M2 +}; + +template +struct alignas(1) float8_base { + + static constexpr bool IS_E4M3 = (T == FloatEncoding::E4M3); + static constexpr bool IS_E5M2 = (T == FloatEncoding::E5M2); + + // Number of Bits representing mantissa and exponents + static constexpr int FP32_NUM_BITS = 32; + static constexpr int FP32_NUM_EXPONENT_BITS = 8; + static constexpr int FP32_NUM_MANTISSA_BITS = 23; + static constexpr uint32_t FP32_NAN = 0x7fffffff; + static constexpr uint32_t FP32_INFINITY_MASK = 0x7f800000; + static constexpr int FP32_MAX_EXPONENT = 127; + static constexpr int FP32_MIN_EXPONENT = -126; + static constexpr int FP32_EXPONENT_BIAS = 127; + + static constexpr int FP16_NUM_BITS = 16; + static constexpr int FP16_NUM_EXPONENT_BITS = 5; + static constexpr int FP16_NUM_MANTISSA_BITS = 10; + static constexpr uint16_t FP16_NAN = 0x7fff; + static constexpr uint16_t FP16_INFINITY_MASK = 0x7c00; + static constexpr int FP16_MAX_EXPONENT = 15; + static constexpr int FP16_MIN_EXPONENT = -14; + static constexpr int FP16_EXPONENT_BIAS = 15; + + static constexpr int FP8_NUM_BITS = 8; + static constexpr int FP8_NUM_EXPONENT_BITS = IS_E4M3 ? 4 : 5; + static constexpr int FP8_NUM_MANTISSA_BITS = IS_E4M3 ? 
3 : 2; + static constexpr uint8_t FP8_NAN = 0x7f; // Also F8_INF + static constexpr uint8_t FP8_INFINITY_MASK = IS_E4M3 ? 0x78 : 0x7c; + static constexpr int FP8_MAX_EXPONENT = IS_E4M3 ? 7 : 15; + static constexpr int FP8_MIN_EXPONENT = IS_E4M3 ? -6 : -14; + static constexpr int FP8_EXPONENT_BIAS = IS_E4M3 ? 7 : 15; + + static constexpr uint8_t FP8_EXPONENT_MASK = (1 << FP8_NUM_EXPONENT_BITS) - 1; + static constexpr uint8_t FP8_MANTISSA_MASK = (1 << FP8_NUM_MANTISSA_BITS) - 1; + + static constexpr uint8_t FP8_MAX_FLT = (IS_E4M3 ? 0x7e : 0x7b); + + // 256 in float + static constexpr uint32_t FP8_SAT_VAL_FP32 = 0x43800000; + + // + // Data members + // + + /// Data container + uint8_t storage; + + /// Ctors. + CUTLASS_HOST_DEVICE + float8_base() : storage(0) { } + + /// Is finite implementation + CUTLASS_HOST_DEVICE + static bool isfinite(float flt) { + uint32_t s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(flt); + #else + std::memcpy(&s, &flt, sizeof(s)); + #endif + + return (s & 0x7f800000) < 0x7f800000; + } + + /// Is NaN implementation + CUTLASS_HOST_DEVICE + static bool isnan(float flt) { + uint32_t s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(flt); + #else + std::memcpy(&s, &flt, sizeof(s)); + #endif + + return (s & 0x7fffffff) > 0x7f800000; + } + + /// Is infinite implementation + CUTLASS_HOST_DEVICE + static bool isinf(float flt) { + uint32_t s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(flt); + #else + std::memcpy(&s, &flt, sizeof(s)); + #endif + + // Sign = 0 for +inf, 1 for -inf + // Exponent = all ones + // Mantissa = all zeros + return (s == 0x7f800000) || (s == 0xff800000); + } + + /// FP32 -> FP8 conversion - rounds to nearest even + CUTLASS_HOST_DEVICE + static uint8_t convert_float_to_fp8(float const& flt) { + + // software implementation rounds toward nearest even + uint32_t s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(flt); + #else + std::memcpy(&s, &flt, sizeof(s)); + #endif + + // Extract the bits in 
the FP32 type + uint8_t sign = uint8_t((s >> 24 & 0x80)); + int32_t exp = int32_t((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS; + int mantissa = s & 0x7fffff; + uint8_t u = 0; + + uint8_t const kF8_NaN = 0x7f; + + // NaN => NaN + if (isnan(flt)) { + return kF8_NaN; + } + + // Inf => MAX_FLT (satfinite) + if (isinf(flt)) { + return sign | FP8_MAX_FLT; + } + + // Special handling + if (exp == -128) { + // int8 range is from -128 to 127 + // So 255(inf) - 127(bias) = 128 - will show up as -128 + + // satfinite + return (sign | FP8_MAX_FLT); + } + + int sticky_bit = 0; + + bool skip_sign = false; + bool may_be_nan = false; + + if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) { + // normal fp32 to normal fp8 + exp = exp + FP8_EXPONENT_BIAS; + u = uint8_t((uint32_t(exp) & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS); + u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS))); + } else if(exp < FP8_MIN_EXPONENT) { + // normal single-precision to subnormal float8-precision representation + int rshift = (FP8_MIN_EXPONENT - exp); + if (rshift < FP32_NUM_BITS) { + mantissa |= (1 << FP32_NUM_MANTISSA_BITS); + + sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0); + + mantissa = (mantissa >> rshift); + u = (uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS- FP8_NUM_MANTISSA_BITS)) & FP8_MANTISSA_MASK); + } else { + mantissa = 0; + u = 0; + } + // Exponent > FP8_MAX_EXPONENT - this is a special case done to match HW + // 0x4380_0000 to 0x43e0_0000 - maps from 256 to 448, and does not saturate / inf. 
+ } else { + if( exp == (FP8_MAX_EXPONENT + 1) ) { + uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)); + if( mantissa_tmp < FP8_MANTISSA_MASK) { + exp = exp + FP8_EXPONENT_BIAS; + u = uint8_t(uint32_t(exp) << FP8_NUM_MANTISSA_BITS) | mantissa_tmp; + may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK-1)); + } else { + // satfinite + return (sign | FP8_MAX_FLT); + } + } else{ + // satfinite + return (sign | FP8_MAX_FLT); + } + } + + // round to nearest even + int NUM_BITS_SHIFT = FP32_NUM_MANTISSA_BITS - (FP8_NUM_MANTISSA_BITS + 1); + int round_bit = ((mantissa >> NUM_BITS_SHIFT) & 1); + sticky_bit |= ((mantissa & ((1 << NUM_BITS_SHIFT) - 1)) != 0); + + if ((round_bit && sticky_bit) || (round_bit && (u & 1))) { + u = uint8_t(u + 1); + if( may_be_nan ) { + skip_sign = true; + } + } + + if (u > FP8_MAX_FLT) { + // satfinite + u = (sign | FP8_MAX_FLT); + } + + if( ! skip_sign ) { + u |= sign; + } + + return u; + } + + + /// Converts a fp8 value stored as a uint8_t to a float + CUTLASS_HOST_DEVICE + static float convert_fp8_to_float(uint8_t const& x) { + + uint32_t constexpr kF32_NaN = 0x7fffffff; + + uint8_t const &f8 = x; + uint32_t sign = (f8 >> (FP8_NUM_BITS - 1)) & 1; + uint32_t exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK; + uint32_t mantissa = f8 & FP8_MANTISSA_MASK; + unsigned f = (sign << (FP32_NUM_BITS-1)); + + if (IS_E4M3 && exp == 15 && mantissa == 0x7) { + f = kF32_NaN; + } + else if (exp > 0 && (IS_E4M3 || exp < (FP8_MAX_EXPONENT + FP8_EXPONENT_BIAS + 1))) { + // normal + exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS); + f = f | + (exp << FP32_NUM_MANTISSA_BITS) | + (mantissa << (FP32_NUM_MANTISSA_BITS-FP8_NUM_MANTISSA_BITS)); + } else if (exp == 0) { + if (mantissa) { + // subnormal + exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS) + 1; + while ((mantissa & (1 << FP8_NUM_MANTISSA_BITS)) == 0) { + mantissa <<= 1; + exp--; + } + mantissa &= FP8_MANTISSA_MASK; + f = f | + (exp << FP32_NUM_MANTISSA_BITS) | 
+ (mantissa << (FP32_NUM_MANTISSA_BITS-FP8_NUM_MANTISSA_BITS)); + } else { + // sign-preserving zero + } + } else { + if(mantissa == 0){ + // Sign-preserving infinity + f = (f | 0x7f800000); + } else { + // Canonical NaN + f = kF32_NaN; + } + } + + #if defined(__CUDA_ARCH__) + return reinterpret_cast(f); + #else + float flt; + std::memcpy(&flt, &f, sizeof(flt)); + return flt; + #endif + } +}; + + +// Forward declaration of float_e5m2_t to define float_e4m3_t <=> float_e5m2_t +// conversions in class float_e4m3_t +struct float_e5m2_t; + + +/////////////////////////////////////////////////////////////// +/// +/// floating-point 8 type : E4M3 +/// +/////////////////////////////////////////////////////////////// +struct alignas(1) float_e4m3_t : float8_base { + + using Base = float8_base; + + static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT; + + // + // Static conversion operators + // + + /// Constructs from an uint8_t + CUTLASS_HOST_DEVICE + static float_e4m3_t bitcast(uint8_t x) { + float_e4m3_t f; + f.storage = x; + return f; + } + + /// FP32 -> FP8 conversion - rounds to nearest even + CUTLASS_HOST_DEVICE + static float_e4m3_t from_float(float const& flt) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t tmp; + float y = float(); + asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt)); + + return *reinterpret_cast(&tmp); + #else + return bitcast(Base::convert_float_to_fp8(flt)); + #endif + } + + /// FP16 -> E5M2 conversion - rounds to nearest even + CUTLASS_HOST_DEVICE + static float_e4m3_t from_half(half const& flt) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t tmp = 0; + uint32_t bits = reinterpret_cast(flt); + asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits)); + + return *reinterpret_cast(&tmp); + #else + return bitcast(Base::convert_float_to_fp8(__half2float(flt))); + #endif + } + + // E4M3 -> half + CUTLASS_HOST_DEVICE + static half to_half(float_e4m3_t const& x) { + #if 
defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t bits = x.storage; + uint32_t packed; + asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); + + return reinterpret_cast(packed).x; + #else + return __float2half(Base::convert_fp8_to_float(x.storage)); + #endif + } + + // E4M3 -> Float + CUTLASS_HOST_DEVICE + static float to_float(float_e4m3_t const& x) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t bits = x.storage; + uint32_t packed; + asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); + + return __half2float(reinterpret_cast(packed).x); + #else + return Base::convert_fp8_to_float(x.storage); + #endif + } + + // + // Methods + // + + /// Constructor inheritance + using Base::Base; + + /// Default constructor + float_e4m3_t() = default; + +#ifdef CUDA_FP8_ENABLED + /// Conversion from CUDA's FP8 type + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(__nv_fp8_e4m3 x) { + storage = x.__x; + } +#endif + + /// Floating point conversion + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(float x) { + storage = from_float(x).storage; + } + + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(half x) { + storage = from_half(x).storage; + } + + /// Floating point conversion + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(double x): float_e4m3_t(float(x)) { + } + + /// Integer conversion + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(int x): float_e4m3_t(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(unsigned x): float_e4m3_t(float(x)) { + } + + /// E5M2 conversion. Defined after float_e5m2_t is defined. 
+ CUTLASS_HOST_DEVICE + explicit float_e4m3_t(float_e5m2_t x); + +#ifdef CUDA_FP8_ENABLED + /// Assignment from CUDA's FP8 type + CUTLASS_HOST_DEVICE + float_e4m3_t & operator=(__nv_fp8_e4m3 x) { + storage = x.__x; + return *this; + } +#endif + + /// Converts to float + CUTLASS_HOST_DEVICE + operator float() const { + return to_float(*this); + } + + /// Converts to half + CUTLASS_HOST_DEVICE + operator half() const { + return to_half(*this); + } + + /// Converts to float + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(to_float(*this)); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + #if defined(__CUDA_ARCH__) + return __half2int_rn(to_half(*this)); + #else + return int(to_float(*this)); + #endif + } + + /// Casts to bool + CUTLASS_HOST_DEVICE + explicit operator bool() const { + #if defined(__CUDA_ARCH__) + return bool(__half2int_rn(to_half(*this))); + #else + return bool(int(to_float(*this))); + #endif + } + + /// Accesses raw internal state + CUTLASS_HOST_DEVICE + uint8_t& raw() { + return storage; + } + + /// Accesses raw internal state + CUTLASS_HOST_DEVICE + uint8_t raw() const { + return storage; + } + + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return exponent_biased() - 15; + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(storage & Base::FP8_MANTISSA_MASK); + } + + CUTLASS_HOST_DEVICE + friend bool isnan(float_e4m3_t const& x) { + return x.storage == uint8_t(0x7f); + } + +}; +/////////////////////////////////////////////////////////////// +/// +/// floating-point 8 type : E5M2 +/// 
///////////////////////////////////////////////////////////////
//
// NOTE(review): throughout this vendored copy, template argument
// lists and cast target types between '<' and '>' appear to have
// been stripped (e.g. "float8_base" without its encoding argument,
// "reinterpret_cast(...)" without a target type). Code tokens are
// preserved as found; verify each site against upstream CUTLASS
// float8.h before building.
//
// E5M2: 1 sign bit, 5 exponent bits, 2 mantissa bits.
// Storage is a single byte ('storage', inherited from the base).
struct alignas(1) float_e5m2_t : float8_base {

  using Base = float8_base;

  static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT;

  //
  // Static conversion operators
  //

  /// Constructs from an uint8_t (raw bit pattern, no value conversion)
  CUTLASS_HOST_DEVICE
  static float_e5m2_t bitcast(uint8_t x) {
    float_e5m2_t f;
    f.storage = x;
    return f;
  }

  /// FP32 -> FP8 conversion - rounds to nearest even
  CUTLASS_HOST_DEVICE
  static float_e5m2_t from_float(float const& flt) {
  #if defined(CUDA_PTX_FP8_CVT_ENABLED)
    // PTX cvt produces a packed pair of e5m2 values; the second source
    // operand (y == 0.0f) fills the unused lane.
    uint16_t tmp;
    float y = float();
    asm volatile("cvt.rn.satfinite.e5m2x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt));

    return *reinterpret_cast(&tmp);
  #else
    // Software fallback: shared base-class round-to-nearest-even converter.
    return bitcast(Base::convert_float_to_fp8(flt));
  #endif
  }

  /// FP16 -> E5M2 conversion - rounds to nearest even
  CUTLASS_HOST_DEVICE
  static float_e5m2_t from_half(half const& flt) {
  #if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t tmp = 0;
    uint32_t bits = reinterpret_cast(flt);
    asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits));

    return *reinterpret_cast(&tmp);
  #else
    // Widen to FP32 first, then reuse the software FP32 -> FP8 path.
    return bitcast(Base::convert_float_to_fp8(__half2float(flt)));
  #endif
  }

  // E5M2 -> half
  CUTLASS_HOST_DEVICE
  static half to_half(float_e5m2_t const& x) {
  #if defined(CUDA_PTX_FP8_CVT_ENABLED)
    // Converts a packed pair; only lane .x carries our value.
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits));

    return reinterpret_cast(packed).x;
  #else
    return __float2half(Base::convert_fp8_to_float(x.storage));
  #endif
  }

  // E5M2 -> Float
  CUTLASS_HOST_DEVICE
  static float to_float(float_e5m2_t const& x) {
  #if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits));

    return __half2float(reinterpret_cast(packed).x);
  #else
    return Base::convert_fp8_to_float(x.storage);
  #endif
  }

  //
  // Methods
  //

  /// Constructor inheritance
  using Base::Base;

  /// Default constructor
  float_e5m2_t() = default;

#ifdef CUDA_FP8_ENABLED
  /// Conversion from CUDA's FP8 type (bit-identical layouts)
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(__nv_fp8_e5m2 x) {
    storage = x.__x;
  }
#endif

  /// Floating point conversion
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(float x) {
    storage = from_float(x).storage;
  }

  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(half x) {
    storage = from_half(x).storage;
  }

  /// Floating point conversion (narrows through float)
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(double x): float_e5m2_t(float(x)) {
  }

  /// Integer conversion (through float)
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(int x): float_e5m2_t(float(x)) {
  }

  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(unsigned x): float_e5m2_t(float(x)) {
  }

  /// E4M3 conversion. Defined after float_e4m3_t is complete.
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(float_e4m3_t x);

#ifdef CUDA_FP8_ENABLED
  /// Assignment from CUDA's FP8 type
  CUTLASS_HOST_DEVICE
  float_e5m2_t & operator=(__nv_fp8_e5m2 x) {
    storage = x.__x;
    return *this;
  }
#endif

  /// Converts to float
  CUTLASS_HOST_DEVICE
  operator float() const {
    return to_float(*this);
  }

  /// Converts to half
  CUTLASS_HOST_DEVICE
  operator half() const {
    return to_half(*this);
  }

  /// Converts to double (through float; E5M2 is exactly representable)
  CUTLASS_HOST_DEVICE
  explicit operator double() const {
    return double(to_float(*this));
  }

  /// Converts to int (round-to-nearest on device)
  CUTLASS_HOST_DEVICE
  explicit operator int() const {
  #if defined(__CUDA_ARCH__)
    return __half2int_rn(to_half(*this));
  #else
    return int(to_float(*this));
  #endif
  }

  /// Casts to bool (true iff the rounded integer value is nonzero)
  CUTLASS_HOST_DEVICE
  explicit operator bool() const {
  #if defined(__CUDA_ARCH__)
    return bool(__half2int_rn(to_half(*this)));
  #else
    return bool(int(to_float(*this)));
  #endif
  }

  /// Accesses raw internal state
  CUTLASS_HOST_DEVICE
  uint8_t& raw() {
    return storage;
  }

  /// Accesses raw internal state
  CUTLASS_HOST_DEVICE
  uint8_t raw() const {
    return storage;
  }

  /// Returns the sign bit
  CUTLASS_HOST_DEVICE
  bool signbit() const {
    return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0);
  }

  /// Returns the biased exponent
  CUTLASS_HOST_DEVICE
  int exponent_biased() const {
    return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK);
  }

  /// Returns the unbiased exponent (E5M2 exponent bias is 15)
  CUTLASS_HOST_DEVICE
  int exponent() const {
    return exponent_biased() - 15;
  }

  /// Returns the mantissa
  CUTLASS_HOST_DEVICE
  int mantissa() const {
    return int(storage & Base::FP8_MANTISSA_MASK);
  }

  // NOTE(review): this recognizes only the canonical NaN 0x7f; other
  // E5M2 NaN encodings (e.g. 0x7d, 0x7e and negative forms) are not
  // matched here — confirm against upstream intent before relying on it.
  CUTLASS_HOST_DEVICE
  friend bool isnan(float_e5m2_t const& x) {
    return x.storage == uint8_t(0x7f);
  }

};
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Arithmetic operators
//
// All operators round-trip through float: convert operands, apply the
// FP32 operation, then narrow back (rounds / saturates per from_float).
//
///////////////////////////////////////////////////////////////////////////////////////////////////

CUTLASS_HOST_DEVICE
bool operator==(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  return float(lhs) == float(rhs);
}

CUTLASS_HOST_DEVICE
bool operator!=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  return float(lhs) != float(rhs);
}

CUTLASS_HOST_DEVICE
bool operator<(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  return float(lhs) < float(rhs);
}

CUTLASS_HOST_DEVICE
bool operator<=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  return float(lhs) <= float(rhs);
}

CUTLASS_HOST_DEVICE
bool operator>(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  return float(lhs) > float(rhs);
}

CUTLASS_HOST_DEVICE
bool operator>=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  return float(lhs) >= float(rhs);
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator+(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  return float_e4m3_t(float(lhs) + float(rhs));
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator-(float_e4m3_t const& lhs) {
  return float_e4m3_t(-float(lhs));
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator-(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  return float_e4m3_t(float(lhs) - float(rhs));
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator*(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  return float_e4m3_t(float(lhs) * float(rhs));
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator/(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  return float_e4m3_t(float(lhs) / float(rhs));
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator+=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  lhs = float_e4m3_t(float(lhs) + float(rhs));
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator-=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  lhs = float_e4m3_t(float(lhs) - float(rhs));
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator*=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  lhs = float_e4m3_t(float(lhs) * float(rhs));
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator/=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  lhs = float_e4m3_t(float(lhs) / float(rhs));
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator++(float_e4m3_t & lhs) {
  float tmp(lhs);
  ++tmp;
  lhs = float_e4m3_t(tmp);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator--(float_e4m3_t & lhs) {
  float tmp(lhs);
  --tmp;
  lhs = float_e4m3_t(tmp);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator++(float_e4m3_t & lhs, int) {
  float_e4m3_t ret(lhs);
  float tmp(lhs);
  tmp++;
  lhs = float_e4m3_t(tmp);
  return ret;
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator--(float_e4m3_t & lhs, int) {
  float_e4m3_t ret(lhs);
  float tmp(lhs);
  tmp--;
  lhs = float_e4m3_t(tmp);
  return ret;
}

CUTLASS_HOST_DEVICE
bool operator==(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  return float(lhs) == float(rhs);
}

CUTLASS_HOST_DEVICE
bool operator!=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  return float(lhs) != float(rhs);
}

CUTLASS_HOST_DEVICE
bool operator<(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  return float(lhs) < float(rhs);
}

CUTLASS_HOST_DEVICE
bool operator<=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  return float(lhs) <= float(rhs);
}

CUTLASS_HOST_DEVICE
bool operator>(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  return float(lhs) > float(rhs);
}

CUTLASS_HOST_DEVICE
bool operator>=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  return float(lhs) >= float(rhs);
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator+(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  return float_e5m2_t(float(lhs) + float(rhs));
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator-(float_e5m2_t const& lhs) {
  return float_e5m2_t(-float(lhs));
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator-(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  return float_e5m2_t(float(lhs) - float(rhs));
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator*(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  return float_e5m2_t(float(lhs) * float(rhs));
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator/(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  return float_e5m2_t(float(lhs) / float(rhs));
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator+=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  lhs = float_e5m2_t(float(lhs) + float(rhs));
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator-=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  lhs = float_e5m2_t(float(lhs) - float(rhs));
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator*=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  lhs = float_e5m2_t(float(lhs) * float(rhs));
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator/=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  lhs = float_e5m2_t(float(lhs) / float(rhs));
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator++(float_e5m2_t & lhs) {
  float tmp(lhs);
  ++tmp;
  lhs = float_e5m2_t(tmp);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator--(float_e5m2_t & lhs) {
  float tmp(lhs);
  --tmp;
  lhs = float_e5m2_t(tmp);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator++(float_e5m2_t & lhs, int) {
  float_e5m2_t ret(lhs);
  float tmp(lhs);
  tmp++;
  lhs = float_e5m2_t(tmp);
  return ret;
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator--(float_e5m2_t & lhs, int) {
  float_e5m2_t ret(lhs);
  float tmp(lhs);
  tmp--;
  lhs = float_e5m2_t(tmp);
  return ret;
}


///////////////////////////////////////////////////////////////
///
/// floating-point 8 type : UE4M3
///
///////////////////////////////////////////////////////////////
// UE4M3:
// 4 Exponent bits, 3 Mantissa bits
// Range: [0:448]
// has_inf: false
// has_NaN: true
// has_denorm: true
// Exponent bias (exp_bias): 7
//
// NOTE(review): the float_exmy_base base here has lost its template
// arguments in this copy — confirm the parameterization upstream.
struct float_ue4m3_t : public float_exmy_base {
  using Base = float_exmy_base;

  float_ue4m3_t() = default;

  /// FP32 -> UE4M3, round-to-nearest with saturation to the finite range
  CUTLASS_HOST_DEVICE
  float_ue4m3_t convert_from_float(float const &flt) const {
  #if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t tmp;
    float y = float();
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt));
    return bitcast(*reinterpret_cast(&tmp));
  #else
    Base::FP32BitRepresentation::Storage fp32_bits = Base::FP32BitRepresentation::to_bits(flt);
    return bitcast(BitRepresentation::convert_from(fp32_bits, Base::FP32BitRepresentation{}));
  #endif
  }

  /// UE4M3 -> FP32 (exact; every UE4M3 value is representable in FP32)
  CUTLASS_HOST_DEVICE
  float convert_to_float(float_ue4m3_t const &x) const {
  #if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits));
    return __half2float(reinterpret_cast(packed).x);
  #else
    Base::FP32BitRepresentation::Storage fp32_bits;
    fp32_bits = Base::BitRepresentation::convert_to(x.storage, Base::FP32BitRepresentation{});
    return detail::copy_bits(fp32_bits);
  #endif
  }

  /// Narrowing conversions route through float
  CUTLASS_HOST_DEVICE
  explicit float_ue4m3_t(double x) : Base(float(x)) {
  }

  CUTLASS_HOST_DEVICE
  explicit float_ue4m3_t(float x) : Base(x) {
  }

  CUTLASS_HOST_DEVICE
  explicit float_ue4m3_t(int x) : Base(x) {
  }

  CUTLASS_HOST_DEVICE
  explicit float_ue4m3_t(unsigned x) : Base(x) {
  }

  CUTLASS_HOST_DEVICE
  float_ue4m3_t(Base x) : Base(x) {
  }

  /// NaN is the single code point 0x7f
  CUTLASS_HOST_DEVICE
  friend bool isnan(float_ue4m3_t const& x) {
    return x.storage == uint8_t(0x7f);
  }

};

/// Defines the size of an element in bits - specialized for float_ue4m3_t
// NOTE(review): the specialization/argument lists below are incomplete in
// this copy (stripped '<...>'); compare with upstream sizeof_bits<float_ue4m3_t>.
template <>
struct sizeof_bits {
  static constexpr int value = sizeof_bits>::value;
};



///////////////////////////////////////////////////////////////
///
/// floating-point 8 type : UE8M0
///
///////////////////////////////////////////////////////////////
// UE8M0:
// 8 Exponent bits, 0 Mantissa bits
// Range: [2^-127:2^127]
// has_inf: false
// has_NaN: true (11111111)
// has_denorm: true
// Exponent bias (exp_bias): 8

struct float_ue8m0_t : public float_exmy_base {
  using Base = float_exmy_base;
  using FP32Bits = typename Base::FP32BitRepresentation;

  float_ue8m0_t() = default;

  /// FP32 -> UE8M0: keeps only the exponent, rounding toward +inf (cvt.rp)
  CUTLASS_HOST_DEVICE
  float_ue8m0_t convert_from_float(float const &flt) const {
  #if defined(CUDA_PTX_UE8M0_CVT_ENABLED)
    uint16_t out;
    asm volatile(
      "{ cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1; }"
      : "=h"(out) : "f"(flt));
    return bitcast(*reinterpret_cast(&out));
  #else
    // NaN / Inf both map to the UE8M0 NaN code point 0xFF
    if (CUTLASS_CMATH_NAMESPACE::isnan(flt) || CUTLASS_CMATH_NAMESPACE::isinf(flt)) {
      return bitcast(0xFF);
    }
    uint32_t flt_uint32 = cutlass::detail::copy_bits(flt);
    uint8_t exp = (flt_uint32 >> 23) & 0xff; // Extract the 8 bit exponent
    uint32_t mant = flt_uint32 & 0x7fffff; // Extract the 23 bit mantissa
    // Do the round up
    // Deals w/ satfinite all at once
    if ((mant > 0) && (exp != 0xFE) && !(exp == 0 && mant <= 0x00400000)) {
      exp++;
    }
    return bitcast(exp);
  #endif
  }

  CUTLASS_HOST_DEVICE
  float convert_to_float(float_ue8m0_t const &x) const {
    //////////////////////////////////////////////////////////////
    // The conversion of UE8M0 to FP32 scale can be done simply
    // with a left shift (No rounding necessary)
    // Note: The base class implements ue8m0 to FP32 based on the rules of float math conversions.
    // The result of current implementation and base class are aligned.
    //////////////////////////////////////////////////////////////
  #if defined(CUDA_PTX_UE8M0_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t bf16x2_val;
    // E8 -> BF16
    asm volatile(
      "{\n"
      "cvt.rn.bf16x2.ue8m0x2 %0, %1;\n"
      "}\n" : "=r"(bf16x2_val): "h"(bits));
    // BF16 -> FP32 (prmt widens the low bf16 lane into a float's high bytes)
    float f1;
    asm(
      "{\n"
      "prmt.b32 %0, %1, %2, %3;\n"
      "}\n"
      : "=f"(f1)
      : "r"(0), "r"(bf16x2_val), "r"(0x5410));
    return f1;
  #else
    using FP32Bits = cutlass::detail::FpBitRepresentation;
    if (x.storage == 0x00) {
      // 2^-127 (FP32 subnormal 0x00400000)
      return cutlass::detail::copy_bits(0x00400000);
    }
    else if (x.storage == 0xFF) {
      // UE8M0 NaN -> canonical FP32 NaN
      return cutlass::detail::copy_bits(0x7fffffff);
    }
    else {
      // Exponent-only value: shift straight into the FP32 exponent field
      auto f8 = static_cast(x.storage);
      FP32Bits::Storage f = (f8 << FP32Bits::NUM_MANTISSA_BITS);
      return cutlass::detail::copy_bits(f);
    }
  #endif
  }

  CUTLASS_HOST_DEVICE
  explicit float_ue8m0_t(double x) : Base(float(x)) {
  }

  CUTLASS_HOST_DEVICE
  explicit float_ue8m0_t(float x) : Base(x) {
  }

  CUTLASS_HOST_DEVICE
  explicit float_ue8m0_t(int x) : Base(x) {
  }

  CUTLASS_HOST_DEVICE
  explicit float_ue8m0_t(unsigned x) : Base(x) {
  }

  CUTLASS_HOST_DEVICE
  float_ue8m0_t(Base x) : Base(x) {
  }

  /// NaN is the single code point 0xff (all exponent bits set)
  CUTLASS_HOST_DEVICE
  friend bool isnan(float_ue8m0_t const& x) {
    return x.storage == uint8_t(0xff);
  }

};

/// Defines the size of an element in bits - specialized for float_ue8m0_t
// NOTE(review): stripped '<...>' here as well; compare with upstream.
template <>
struct sizeof_bits {
  static constexpr int value = sizeof_bits>::value;
};


///////////////////////////////////////////////////////////////////////////////////////////////////
//
// float_e4m3_t <=> float_e5m2_t conversions
//
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// float_e4m3_t <= float_e5m2_t +CUTLASS_HOST_DEVICE +float_e4m3_t::float_e4m3_t(float_e5m2_t x) { + storage = from_float(float_e5m2_t::to_float(x)).storage; +} + +/// float_e5m2_t <= float_e4m3_t +CUTLASS_HOST_DEVICE +float_e5m2_t::float_e5m2_t(float_e4m3_t x) { + storage = from_float(float_e4m3_t::to_float(x)).storage; +} + +/////////////////////////////////////////////////////////////// +/// +/// Umbrella floating-point 8-bit data type : type_erased_dynamic_float8_t +/// This umbrella datatype can be enabled when a user provides a specific +/// datatype in runtime argument list. +/// +/// Currently supported runtime datatypes compatible with type_erased_dynamic_float8_t: +/// MXF8F6F4Format::E5M2 +/// MXF8F6F4Format::E4M3 +/// +/////////////////////////////////////////////////////////////// + +union type_erased_dynamic_float8_t { + uint8_t data; + cutlass::float_e5m2_t e5m2; + cutlass::float_e4m3_t e4m3; + CUTLASS_HOST_DEVICE + explicit operator cutlass::float_e5m2_t() const { + return e5m2; + } + + CUTLASS_HOST_DEVICE + explicit operator cutlass::float_e4m3_t() const { + return e4m3; + } + +}; + + + +/////////////////////////////////////////////////////////////// +/// MX type for float8 +/// Intended to be used in builders +/////////////////////////////////////////////////////////////// + +template +struct mx_float8_t { + static_assert(cute::is_same_v + || cute::is_same_v + || cute::is_same_v + , "Only float_e5m2_t, float_e4m3_t can have scale factors for MXFP8"); + using ScaleFactorType = cutlass::float_ue8m0_t; + using DataType = F8Type; +}; + +using type_erased_dynamic_mx_float8_t = mx_float8_t; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Standard Library 
operations and definitions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +namespace std { + +/// Numeric limits common to all float8 types +template +struct float8_base_numeric_limits { +private: + using F8Type = T; +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static std::float_denorm_style const has_denorm = std::denorm_present; + static bool const has_denorm_loss = true; + static std::float_round_style const round_style = std::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = F8Type::FP8_NUM_MANTISSA_BITS; + + /// Least positive value + CUTLASS_HOST_DEVICE + static F8Type min() { return F8Type::bitcast(0x01); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static F8Type max() { return F8Type::bitcast(F8Type::FP8_MAX_FLT); } + + /// Returns maximum rounding error + CUTLASS_HOST_DEVICE + static F8Type round_error() { return F8Type(0.5f); } + + /// Returns positive infinity value + CUTLASS_HOST_DEVICE + static F8Type infinity() { return F8Type::bitcast(F8Type::FP8_INFINITY_MASK); } + + /// Returns quiet NaN value + CUTLASS_HOST_DEVICE + static F8Type quiet_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } + + /// Returns signaling NaN value + CUTLASS_HOST_DEVICE + static F8Type signaling_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } + + /// Returns smallest positive subnormal value + CUTLASS_HOST_DEVICE + static F8Type denorm_min() { return F8Type::bitcast(0x01); } +}; + +/// Numeric limits for float_e4m3_t +template <> +struct numeric_limits : + public float8_base_numeric_limits { + static bool const has_infinity = false; + + /// Minimum 
finite value + static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); } + + /// Machine epsilon, that is, the difference between 1.0 and the next representable value + static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); } +}; + +/// Numeric limits for float_e5m2_t +template <> +struct numeric_limits : + public float8_base_numeric_limits { + static bool const has_infinity = true; + + /// Minimum finite value + static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); } + + /// Machine epsilon, that is, the difference between 1.0 and the next representable value + static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } +}; + + +template +struct float8_exmy_numeric_limits +{ +private: + using type = T; + +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static bool const has_denorm_loss = true; + static cutlass::platform::float_denorm_style const has_denorm = cutlass::platform::denorm_present; + static cutlass::platform::float_round_style const round_style = cutlass::platform::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = type::Base::BitRepresentation::NUM_MANTISSA_BITS; + static bool const has_infinity = false; + + /// Least positive value + CUTLASS_HOST_DEVICE + static type min() { return type::bitcast(0x01); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static type max() { return type::bitcast(type::Base::BitRepresentation::MAX_VALUE); } + + /// Returns maximum rounding error + CUTLASS_HOST_DEVICE + static type round_error() { return type(0.5f); } + + /// Returns positive infinity value + CUTLASS_HOST_DEVICE + 
static type infinity() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns quiet NaN value + CUTLASS_HOST_DEVICE + static type quiet_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns signaling NaN value + CUTLASS_HOST_DEVICE + static type signaling_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns smallest positive subnormal value + CUTLASS_HOST_DEVICE + static type denorm_min() { return type::bitcast(0x01); } +}; + +/// Numeric limits for float_ue8m0_t +template <> +struct numeric_limits : + public float8_exmy_numeric_limits { + static bool const has_infinity = false; + static bool const is_signed = false; + + /// Minimum finite value + static cutlass::float_ue8m0_t lowest() { return cutlass::float_ue8m0_t::bitcast(0xfe); } + + /// Machine epsilon, that is, the difference between 1.0 and the next representable value (2^0) + static cutlass::float_ue8m0_t epsilon() { return cutlass::float_ue8m0_t::bitcast(0x7f); } +}; + + +} // namespace std +#endif + +namespace cutlass { +namespace platform { + +/// Numeric limits common to all float8 types +template +struct float8_base_numeric_limits { +private: + using F8Type = T; +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; +#if !defined(__CUDACC_RTC__) + static std::float_denorm_style const has_denorm = std::denorm_present; +#endif + static bool const has_denorm_loss = true; +#if !defined(__CUDACC_RTC__) + static std::float_round_style const round_style = std::round_to_nearest; +#endif + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = F8Type::FP8_NUM_MANTISSA_BITS; + + /// Least positive value + CUTLASS_HOST_DEVICE + 
static F8Type min() { return F8Type::bitcast(0x01); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static F8Type max() { return F8Type::bitcast(F8Type::FP8_MAX_FLT); } + + /// Returns maximum rounding error + CUTLASS_HOST_DEVICE + static F8Type round_error() { return F8Type(0.5f); } + + /// Returns positive infinity value + CUTLASS_HOST_DEVICE + static F8Type infinity() { return F8Type::bitcast(F8Type::FP8_INFINITY_MASK); } + + /// Returns quiet NaN value + CUTLASS_HOST_DEVICE + static F8Type quiet_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } + + /// Returns signaling NaN value + CUTLASS_HOST_DEVICE + static F8Type signaling_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } + + /// Returns smallest positive subnormal value + CUTLASS_HOST_DEVICE + static F8Type denorm_min() { return F8Type::bitcast(0x01); } +}; + +/// Forward Declaration +template +struct numeric_limits; + +/// Numeric limits for float_e4m3_t +template <> +struct numeric_limits : + public float8_base_numeric_limits { + static bool const has_infinity = false; + + /// Minimum finite value + static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); } + + /// Machine epsilon, that is, the difference between 1.0 and the next representable value + static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); } +}; + +/// Numeric limits for float_e5m2_t +template <> +struct numeric_limits : + public float8_base_numeric_limits { + static bool const has_infinity = true; + + /// Minimum finite value + static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); } + + /// Machine epsilon, that is, the difference between 1.0 and the next representable value + static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } +}; + + +template +struct float8_exmy_numeric_limits +{ +private: + using type = T; + +public: + static bool const is_specialized = true; + static bool const is_signed = true; + 
static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static bool const has_denorm_loss = true; + static cutlass::platform::float_denorm_style const has_denorm = cutlass::platform::denorm_present; + static cutlass::platform::float_round_style const round_style = cutlass::platform::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = type::Base::BitRepresentation::NUM_MANTISSA_BITS; + static bool const has_infinity = false; + + /// Least positive value + CUTLASS_HOST_DEVICE + static type min() { return type::bitcast(0x01); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static type max() { return type::bitcast(type::Base::BitRepresentation::MAX_VALUE); } + + /// Returns maximum rounding error + CUTLASS_HOST_DEVICE + static type round_error() { return type(0.5f); } + + /// Returns positive infinity value + CUTLASS_HOST_DEVICE + static type infinity() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns quiet NaN value + CUTLASS_HOST_DEVICE + static type quiet_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns signaling NaN value + CUTLASS_HOST_DEVICE + static type signaling_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns smallest positive subnormal value + CUTLASS_HOST_DEVICE + static type denorm_min() { return type::bitcast(0x01); } +}; + +/// Numeric limits for float_ue8m0_t +template <> +struct numeric_limits : + public float8_exmy_numeric_limits { + static bool const has_infinity = false; + static bool const is_signed = false; + + /// Minimum finite value + static cutlass::float_ue8m0_t lowest() { return cutlass::float_ue8m0_t::bitcast(0xfe); } + + /// Machine epsilon, that is, the difference between 1.0 and the next representable 
value (2^0) + static cutlass::float_ue8m0_t epsilon() { return cutlass::float_ue8m0_t::bitcast(0x7f); } +}; + + +} // namespace platform + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// User-defined literals +// + +CUTLASS_HOST_DEVICE +cutlass::float_e4m3_t operator "" _fe4m3(long double x) { + return cutlass::float_e4m3_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e4m3_t operator "" _fe4m3(unsigned long long int x) { + return cutlass::float_e4m3_t(int(x)); +} + + +CUTLASS_HOST_DEVICE +cutlass::float_ue4m3_t operator "" _fue4m3(long double x) { + return cutlass::float_ue4m3_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_ue4m3_t operator "" _fue4m3(unsigned long long int x) { + return cutlass::float_ue4m3_t(int(x)); +} + + +CUTLASS_HOST_DEVICE +cutlass::float_e5m2_t operator "" _fe5m2(long double x) { + return cutlass::float_e5m2_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e5m2_t operator "" _fe5m2(unsigned long long int x) { + return cutlass::float_e5m2_t(int(x)); +} + + +CUTLASS_HOST_DEVICE +cutlass::float_ue8m0_t operator "" _fue8m0(long double x) +{ + return cutlass::float_ue8m0_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_ue8m0_t operator "" _fue8m0(unsigned long long int x) +{ + return cutlass::float_ue8m0_t(int(x)); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/float_subbyte.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/float_subbyte.h new file mode 100644 index 0000000000000000000000000000000000000000..eefab027291f6dcbec5dc795b2cf8f50b1728d4e --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/float_subbyte.h @@ -0,0 +1,797 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! 
+ \file + \brief Defines classes for FP4/FP6 datatypes +*/ +#pragma once + +#include "cutlass/arch/config.h" +#include "cutlass/float8.h" + +// FP4 types are available starting CUDA 12+ +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUDA_FP4_ENABLED 1 +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# define CUDA_PTX_FP4FP6_CVT_ENABLED 1 +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) +# define CUDA_PTX_FP4FP6_CVT_ENABLED 1 +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/exmy_base.h" + +#include "cute/util/type_traits.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +// FP4 and FP6 types +struct float_e2m1_t; +struct float_e3m2_t; +// E2M1: +// 2 Exponent bits with 1 Mantissa bit +// Range: +-[0,0.5,1,1.5,2,3,4,5,6] +// has_Inf: false +// has_NaN: false +// has_denorm: true +// Exponent bias (exp_bias): 1 + +struct float_e2m1_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e2m1_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e2m1_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m1_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m1_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e2m1_t(Base x) : Base(x) { + } +}; + +namespace detail { + +// This new type is used to select correct MMA type and TMA type. 
+struct float_e2m1_unpacksmem_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e2m1_unpacksmem_t() = default; + + CUTLASS_HOST_DEVICE + float_e2m1_unpacksmem_t(float_e2m1_unpacksmem_t const& x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m1_unpacksmem_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m1_unpacksmem_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m1_unpacksmem_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e2m1_unpacksmem_t(Base x) : Base(x) { + } +}; + +} // namespace detail + +/// Defines the size of an element in bits - specialized for float_e2m1_t +template <> +struct sizeof_bits { + static constexpr int value = 4; +}; + +template <> +struct sizeof_bits { + static constexpr int value = 4; +}; + +CUTLASS_HOST_DEVICE +float_e2m1_t abs(float_e2m1_t const& val) { + using BaseType = typename float_e2m1_t::Base; + return float_e2m1_t(abs(BaseType{val.raw()})); +} + + +// E2M3: +// 2 Exponent bits with 3 Mantissa bit +// Range: [-7.5,+7.5] +// has_Inf: false +// has_NaN: false +// has_denorm: true +// Exponent bias (exp_bias): 1 + +struct float_e2m3_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e2m3_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e2m3_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e2m3_t(Base x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_t(float_e3m2_t x); +}; + +namespace detail { + +struct float_e2m3_unpack8bits_t: public float_exmy_base { + // Used in register. 
+ using Base = float_exmy_base; + + float_e2m3_unpack8bits_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpack8bits_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpack8bits_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpack8bits_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e2m3_unpack8bits_t(Base x) : Base(x) { + } +}; + +// This new type is used to select correct MMA type and TMA type. +struct float_e2m3_unpacksmem_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e2m3_unpacksmem_t() = default; + + CUTLASS_HOST_DEVICE + float_e2m3_unpacksmem_t(float_e2m3_unpacksmem_t const& x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpacksmem_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpacksmem_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e2m3_unpacksmem_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e2m3_unpacksmem_t(Base x) : Base(x) { + } +}; + +} // namespace detail + +/// Defines the size of an element in bits - specialized for float_e2m3_t +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + +/// Defines the size of an element in bits - specialized for float_e2m3_unpacksmem_t +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + +CUTLASS_HOST_DEVICE +float_e2m3_t abs(float_e2m3_t const& val) { + using BaseType = typename float_e2m3_t::Base; + return float_e2m3_t(abs(BaseType{val.raw()})); +} + +// E3M2: +// 3 Exponent bits, 2 Mantissa bits +// Range: [-28:+28] +// has_inf: false +// has_NaN: false +// has_denorm: true +// Exponent bias (exp_bias): 3 + +struct float_e3m2_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e3m2_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e3m2_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_t(float x) : Base(x) { + } + + 
CUTLASS_HOST_DEVICE + explicit float_e3m2_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e3m2_t(Base x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_t(float_e2m3_t x); +}; + +namespace detail { + +struct float_e3m2_unpack8bits_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e3m2_unpack8bits_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpack8bits_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpack8bits_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpack8bits_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e3m2_unpack8bits_t(Base x) : Base(x) { + } +}; + +// This new type is used to select correct MMA type and TMA type. +struct float_e3m2_unpacksmem_t : public float_exmy_base { + + using Base = float_exmy_base; + + float_e3m2_unpacksmem_t() = default; + + CUTLASS_HOST_DEVICE + float_e3m2_unpacksmem_t(float_e3m2_unpacksmem_t const& x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpacksmem_t(double x) : Base(float(x)) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpacksmem_t(float x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + explicit float_e3m2_unpacksmem_t(int x) : Base(x) { + } + + CUTLASS_HOST_DEVICE + float_e3m2_unpacksmem_t(Base x) : Base(x) { + } +}; + +} // namespace detail + +/// Defines the size of an element in bits - specialized for float_e3m2_t +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + +/// Defines the size of an element in bits - specialized for float_e3m2_unpacksmem_t +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + +CUTLASS_HOST_DEVICE +float_e3m2_t abs(float_e3m2_t const& val) { + using BaseType = typename float_e3m2_t::Base; + return float_e3m2_t(abs(BaseType{val.raw()})); +} + +/// Defines the size of an element in bits - specialized for float_e3m2_unpack8bits_t +template <> +struct sizeof_bits { + static constexpr int value = 8; 
+}; + +/// Defines the size of an element in bits - specialized for float_e2m3_unpack8bits_t +template <> +struct sizeof_bits { + static constexpr int value = 8; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Get the register type used in kernel +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct get_unpacked_element_type; + +template <> +struct get_unpacked_element_type { + using type = detail::float_e2m3_unpack8bits_t; +}; + +template <> +struct get_unpacked_element_type { + using type = detail::float_e3m2_unpack8bits_t; +}; +} // namespace detail +// /////////////////////////////////////////////////////////////////////////////////////////////////// +// // +// // float_e2m3_t <=> float_e3m2_t conversions +// // +// /////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +float_e2m3_t::float_e2m3_t(float_e3m2_t x) +{ + storage = convert_from_float(float(x)).storage; +} + +CUTLASS_HOST_DEVICE +float_e3m2_t::float_e3m2_t(float_e2m3_t x) +{ + storage = convert_from_float(float(x)).storage; +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////// +/// +/// Umbrella floating-point 6-bit data type : type_erased_dynamic_float6_t +/// This umbrella datatype can be enabled when a user provides a specific +/// datatype in runtime argument list. 
+/// +/// Currently supported runtime datatypes compatible with type_erased_dynamic_float6_t: +/// MXF8F6F4Format::E2M3 +/// MXF8F6F4Format::E3M2 +/// +/////////////////////////////////////////////////////////////// + +union type_erased_dynamic_float6_t { + cutlass::float_e2m3_t e2m3; + cutlass::float_e3m2_t e3m2; + + CUTLASS_HOST_DEVICE + explicit operator cutlass::float_e2m3_t() const { + return e2m3; + } + + CUTLASS_HOST_DEVICE + explicit operator cutlass::float_e3m2_t() const { + return e3m2; + } +}; + +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + +/////////////////////////////////////////////////////////////// +/// +/// Umbrella floating-point 4-bit data type : type_erased_dynamic_float4_t +/// This umbrella datatype can be enabled when a user provides a specific +/// datatype in runtime argument list. +/// +/// Currently supported runtime datatypes compatible with type_erased_dynamic_float4_t: +/// MXF8F6F4Format::E2M1 +/// +/////////////////////////////////////////////////////////////// + +union type_erased_dynamic_float4_t { + cutlass::float_e2m1_t e2m1; + CUTLASS_HOST_DEVICE + explicit operator cutlass::float_e2m1_t() const { + return e2m1; + } +}; + +template <> +struct sizeof_bits { + static constexpr int value = 4; +}; + + +/////////////////////////////////////////////////////////////// +/// MX/NV types for float6 and float4 +/// Intended to be used in builders +/////////////////////////////////////////////////////////////// + +template +struct mx_float6_t +{ + static_assert(cute::is_same_v + || cute::is_same_v + || cute::is_same_v + , "Only float_e2m3_t, float_e3m2_t can have scale factors for MXFP6"); + using ScaleFactorType = cutlass::float_ue8m0_t; + using DataType = F6Type; +}; + +using type_erased_dynamic_mx_float6_t = mx_float6_t; + +template +struct mx_float4_t +{ + static_assert(cute::is_same_v + || cute::is_same_v + , "Only float_e2m1_t type_erased_dynamic_float4_t can have scale factors for MXFP4"); + using 
ScaleFactorType = cutlass::float_ue8m0_t; + using DataType = F4Type; +}; + +using type_erased_dynamic_mx_float4_t = mx_float4_t; + +template +struct nv_float4_t +{ + static_assert(cute::is_same_v + || cute::is_same_v + , "Only float_e2m1_t type_erased_dynamic_float4_t can have scale factors for NVFP4"); + using ScaleFactorType = cutlass::float_ue4m3_t; + using DataType = F4Type; +}; + +using type_erased_dynamic_nv_float4_t = nv_float4_t; + + +namespace detail { + +union type_erased_dynamic_float6_unpacksmem_t { + cutlass::detail::float_e2m3_unpacksmem_t e2m3_unpacksmem; + cutlass::detail::float_e3m2_unpacksmem_t e3m2_unpacksmem; + + CUTLASS_HOST_DEVICE + explicit operator cutlass::detail::float_e2m3_unpacksmem_t() const { + return e2m3_unpacksmem; + } + + CUTLASS_HOST_DEVICE + explicit operator cutlass::detail::float_e3m2_unpacksmem_t() const { + return e3m2_unpacksmem; + } +}; + +union type_erased_dynamic_float4_unpacksmem_t { + cutlass::detail::float_e2m1_unpacksmem_t e2m1_unpacksmem; + + CUTLASS_HOST_DEVICE + explicit operator cutlass::detail::float_e2m1_unpacksmem_t() const { + return e2m1_unpacksmem; + } +}; + +}; + +template <> +struct sizeof_bits { + static constexpr int value = 6; +}; + + +template <> +struct sizeof_bits { + static constexpr int value = 4; +}; + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Standard Library operations and definitions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +#if !defined(__CUDACC_RTC__) +namespace std { +/// Numeric limits common to all float4 types +template +struct float_subbyte_base_numeric_limits +{ +private: + using type = T; + +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = false; + static bool const 
has_signaling_NaN = false; + static bool const has_denorm_loss = true; + static cutlass::platform::float_denorm_style const has_denorm = cutlass::platform::denorm_present; + static cutlass::platform::float_round_style const round_style = cutlass::platform::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = type::Base::BitRepresentation::NUM_MANTISSA_BITS; + static bool const has_infinity = false; + + /// Least positive value + static type min() { return type::bitcast(0x01); } + + /// Maximum finite value + static type max() { return type::bitcast(type::Base::BitRepresentation::MAX_VALUE); } + + /// Returns maximum rounding error + static type round_error() { return type(0.5f); } + + /// Returns positive infinity value + static type infinity() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns quiet NaN value + static type quiet_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns signaling NaN value + static type signaling_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns smallest positive subnormal value + static type denorm_min() { return type::bitcast(0x01); } +}; +/// Numeric limits for float_e2m1_t +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e2m1_t lowest() { return cutlass::float_e2m1_t::bitcast(0xf); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e2m1_t epsilon() { return cutlass::float_e2m1_t::bitcast(0x1); } +}; + +/// Numeric limits for float_e2m3_t +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e2m3_t lowest() { return cutlass::float_e2m3_t::bitcast(0x2f); } + + /// 
Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e2m3_t epsilon() { return cutlass::float_e2m3_t::bitcast(0x1); } +}; + +/// Numeric limits for float_e3m2_t + +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e3m2_t lowest() { return cutlass::float_e3m2_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e3m2_t epsilon() { return cutlass::float_e3m2_t::bitcast(0x4); } +}; +} // namespace std +#endif + +namespace cutlass { +namespace platform { + +/// Numeric limits common to all float4 types +template +struct float_subbyte_base_numeric_limits +{ +private: + using type = T; + +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = false; + static bool const has_signaling_NaN = false; + static bool const has_denorm_loss = true; + static cutlass::platform::float_denorm_style const has_denorm = cutlass::platform::denorm_present; + static cutlass::platform::float_round_style const round_style = cutlass::platform::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = type::Base::BitRepresentation::NUM_MANTISSA_BITS; + static bool const has_infinity = false; + + /// Least positive value + static type min() { return type::bitcast(0x01); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE static type max() { return type::bitcast(type::Base::BitRepresentation::MAX_VALUE); } + + /// Returns maximum rounding error + static type round_error() { return type(0.5f); } + + /// Returns positive infinity value + static type infinity() { return 
type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns quiet NaN value + static type quiet_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns signaling NaN value + static type signaling_NaN() { return type::bitcast(type::Base::BitRepresentation::INF_MASK); } + + /// Returns smallest positive subnormal value + static type denorm_min() { return type::bitcast(0x01); } +}; + +/// Forward Declaration +template +struct numeric_limits; +/// Numeric limits for float_e2m1_t +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e2m1_t lowest() { return cutlass::float_e2m1_t::bitcast(0xf); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e2m1_t epsilon() { return cutlass::float_e2m1_t::bitcast(0x1); } +}; + +/// Numeric limits for float_e2m3_t +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e2m3_t lowest() { return cutlass::float_e2m3_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e2m3_t epsilon() { return cutlass::float_e2m3_t::bitcast(0x1); } +}; + +/// Numeric limits for float_e3m2_t + +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::float_e3m2_t lowest() { return cutlass::float_e3m2_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::float_e3m2_t epsilon() { return cutlass::float_e3m2_t::bitcast(0x4); } +}; + +/// Numeric limits for float_e2m3_unpack8bits_t +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum 
finite value + static cutlass::detail::float_e2m3_unpack8bits_t lowest() { return cutlass::detail::float_e2m3_unpack8bits_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::detail::float_e2m3_unpack8bits_t epsilon() { return cutlass::detail::float_e2m3_unpack8bits_t::bitcast(0x1); } +}; + +/// Numeric limits for float_e3m2_unpack8bits_t + +template <> +struct numeric_limits : public float_subbyte_base_numeric_limits +{ + /// Minimum finite value + static cutlass::detail::float_e3m2_unpack8bits_t lowest() { return cutlass::detail::float_e3m2_unpack8bits_t::bitcast(0x2f); } + + /// Returns machine epsilon, that is, the difference between 1.0 and the next value representable by the floating-point + static cutlass::detail::float_e3m2_unpack8bits_t epsilon() { return cutlass::detail::float_e3m2_unpack8bits_t::bitcast(0x4); } +}; +} // namespace platform + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// User-defined literals +// +CUTLASS_HOST_DEVICE +cutlass::float_e2m1_t operator"" _fe2m1(long double x) +{ + return cutlass::float_e2m1_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e2m1_t operator"" _fe2m1(unsigned long long int x) +{ + return cutlass::float_e2m1_t(int(x)); +} +CUTLASS_HOST_DEVICE +cutlass::float_e2m3_t operator"" _fe2m3(long double x) +{ + return cutlass::float_e2m3_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e2m3_t operator"" _fe2m3(unsigned long long int x) +{ + return cutlass::float_e2m3_t(int(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e3m2_t operator"" _fe3m2(long double x) +{ + return cutlass::float_e3m2_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e3m2_t operator"" _fe3m2(unsigned long long int x) +{ + return cutlass::float_e3m2_t(int(x)); +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/floating_point_nvrtc.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/floating_point_nvrtc.h new file mode 100644 index 0000000000000000000000000000000000000000..6496fea077d59e0c0f7dfbf946534416c2189ca9 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/floating_point_nvrtc.h @@ -0,0 +1,104 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines categories for floating point numbers for use in NVRTC-compiled code +*/ + +#pragma once + +#include // CUTLASS_HOST_DEVICE +#include // uint32_t +#if !defined(__CUDACC_RTC__) +#include // std::memcpy +#endif + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// All floating-point numbers can be put in one of these categories. 
+enum { + FP_NAN = +# define FP_NAN 0 + FP_NAN, + FP_INFINITE = +# define FP_INFINITE 1 + FP_INFINITE, + FP_ZERO = +# define FP_ZERO 2 + FP_ZERO, + FP_SUBNORMAL = +# define FP_SUBNORMAL 3 + FP_SUBNORMAL, + FP_NORMAL = +# define FP_NORMAL 4 + FP_NORMAL +}; + +CUTLASS_HOST_DEVICE +int fpclassify(float const& f) { + + uint32_t s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(f); + #else + std::memcpy(&s, &f, sizeof(s)); + #endif + + uint32_t exp = s & 0x7f800000; + uint32_t mantissa = s & 0x007fffff; + + if (exp == 0x7f800000) { + if (mantissa) { + return FP_NAN; + } + else { + return FP_INFINITE; + } + } + else if (!exp) { + if (mantissa) { + return FP_SUBNORMAL; + } + else { + return FP_ZERO; + } + } + return FP_NORMAL; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/functional.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/functional.h new file mode 100644 index 0000000000000000000000000000000000000000..636cb8ca8a388430acdf1678f45045ab1805f9b6 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/functional.h @@ -0,0 +1,1106 @@ + /*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Define basic numeric operators + + This is inspired by the Standard Library's header. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#if defined(__CUDACC_RTC__) +#include "cutlass/floating_point_nvrtc.h" +#endif + +#include + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include +#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) + +#ifdef _MSC_VER +// Provides support for alternate operators such as 'and', 'or', ... 
+#include +#include +#endif // _MSC_VER + +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) +# define CUTLASS_ARCH_CREDUX_ENABLED +#endif + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + + CUTLASS_HOST_DEVICE int32_t popcount(int32_t x) { + #if defined(__CUDA_ARCH__) + return __popc(x); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_popcount(x); + #elif (defined(_MSC_VER) && !defined(_M_ARM64)) + return __popcnt(x); + #else + int32_t count = 0; + while (x) { + count += x & 1; + x >>= 1; + } + return count; + #endif + } + + CUTLASS_HOST_DEVICE int64_t popcount(int64_t x) { + #if defined(__CUDA_ARCH__) + return __popcll(x); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_popcountll(x); + #elif (defined(_MSC_VER) && !defined(_M_ARM64)) + return __popcnt64(x); + #else + int64_t count = 0; + while (x) { + count += x & 1; + x >>= 1; + } + return count; + #endif + } + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct absolute_value_op { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { + return abs(lhs); + } +}; + +template <> +struct absolute_value_op { + CUTLASS_HOST_DEVICE + float operator()(float lhs) const { return fabs(lhs); } +}; + +template +struct plus { + CUTLASS_HOST_DEVICE + T operator()(T lhs, T const &rhs) const { + lhs += rhs; + return lhs; + } +}; + +template +struct minus { + CUTLASS_HOST_DEVICE + T operator()(T lhs, T const &rhs) const { + lhs -= rhs; + return lhs; + } +}; + +template +struct multiplies { + CUTLASS_HOST_DEVICE + T operator()(T lhs, T const &rhs) const { + lhs *= rhs; + return lhs; + } +}; + +template +struct scale { + T const scaling_factor_; + + CUTLASS_HOST_DEVICE + scale(float 
scaling_factor) : scaling_factor_(scaling_factor) { + } + + T operator()(T const &rhs) const { + T result = rhs * scaling_factor_; + return result; + } +}; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 +/// Partial specializations needed when __CUDA_NO_HALF2_OPERATORS__ is set +template<> +struct plus<__half2> { + CUTLASS_HOST_DEVICE + __half2 operator()(__half2 lhs, __half2 const &rhs) const { + return __hadd2(lhs, rhs); + } +}; + +template<> +struct minus<__half2> { + CUTLASS_HOST_DEVICE + __half2 operator()(__half2 lhs, __half2 const &rhs) const { + return __hsub2(lhs, rhs); + } +}; + +template<> +struct multiplies<__half2> { + CUTLASS_HOST_DEVICE + __half2 operator()(__half2 lhs, __half2 const &rhs) const { + return __hmul2(lhs, rhs); + } +}; + +/// Partial specializations needed when __CUDA_NO_HALF_OPERATORS__ is set +template<> +struct plus<__half> { + CUTLASS_HOST_DEVICE + __half operator()(__half lhs, __half const &rhs) const { + return __hadd(lhs, rhs); + } +}; + +template<> +struct minus<__half> { + CUTLASS_HOST_DEVICE + __half operator()(__half lhs, __half const &rhs) const { + return __hsub(lhs, rhs); + } +}; + +template<> +struct multiplies<__half> { + CUTLASS_HOST_DEVICE + __half operator()(__half lhs, __half const &rhs) const { + return __hmul(lhs, rhs); + } +}; +#endif // defined(__CUDA_ARCH__) + + +/// Squares with optional conversion +template +struct square { + CUTLASS_HOST_DEVICE + Output operator()(T lhs) const { + multiplies mul_op; + + Output y = Output(lhs); + return mul_op(y, y); + } +}; + +/// Returns the magnitude squared of an element. 
+template +struct magnitude_squared { + CUTLASS_HOST_DEVICE + Output operator()(T lhs) const { + multiplies mul_op; + + Output y = Output(lhs); + return mul_op(y, y); + } +}; + +/// Computes the square of a difference with optional conversion +template +struct square_difference { + CUTLASS_HOST_DEVICE + Output operator()(T lhs, T rhs) const { + multiplies mul_op; + + Output y = Output(lhs) - Output(rhs); + return mul_op(y, y); + } +}; + +/// Computes the square of a difference with optional conversion +template +struct magnitude_squared_difference { + CUTLASS_HOST_DEVICE + Output operator()(T lhs, T rhs) const { + multiplies mul_op; + + Output y = Output(lhs) - Output(rhs); + return mul_op(y, y); + } +}; + +// Computes the reciprocal square root +template +struct inverse_square_root; + +template <> +struct inverse_square_root { + CUTLASS_HOST_DEVICE + float operator()(float const &lhs) const { +#if defined(__CUDA_ARCH__) + return rsqrtf(lhs); +#else + return 1.f / std::sqrt(lhs); +#endif + } +}; + +template <> +struct inverse_square_root { + CUTLASS_HOST_DEVICE + half_t operator()(half_t const &lhs) const { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ > 520) + auto result = hrsqrt(reinterpret_cast<__half const &>(lhs)); + return reinterpret_cast(result); +#else + return half_t(1.f / std::sqrt(half_t::convert(lhs))); +#endif + } +}; + +/// Divides +template +struct divides { + CUTLASS_HOST_DEVICE + T operator()(T lhs, T const &rhs) const { + lhs /= rhs; + return lhs; + } +}; + +/// reciprocal_approximate +template +struct reciprocal_approximate { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { + return divides{}(T(1), lhs); + } +}; + +template <> +struct reciprocal_approximate { + CUTLASS_HOST_DEVICE + float operator()(float lhs) const { + float ret; + #if defined(__CUDA_ARCH__) + asm volatile ("rcp.approx.f32 %0, %1;\n" : "=f"(ret) : "f"(lhs)); + #else + ret = 1.0f / lhs; + #endif + return ret; + } +}; + + +template <> +struct reciprocal_approximate { + 
CUTLASS_HOST_DEVICE + cutlass::float_ue8m0_t operator()(cutlass::float_ue8m0_t lhs) const { + return cutlass::float_ue8m0_t::bitcast(static_cast(static_cast(254u) - lhs.storage)); + } +}; + + +/// reciprocal_approximate with ftz +template +struct reciprocal_approximate_ftz : reciprocal_approximate +{}; + +template <> +struct reciprocal_approximate_ftz { + CUTLASS_HOST_DEVICE + float operator()(float lhs) const { + float ret; + #if defined(__CUDA_ARCH__) + asm volatile ("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(ret) : "f"(lhs)); + #else + if (std::fpclassify(lhs) == FP_SUBNORMAL) { + lhs = 0.0f; + } + ret = 1.0f / lhs; + if (std::fpclassify(ret) == FP_SUBNORMAL) { + ret = 0.0f; + } + #endif + return ret; + } +}; + +/// Negate +template +struct negate { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { + return -lhs; + } +}; + +/// Greater equal +template +struct greater_equal { + CUTLASS_HOST_DEVICE + bool operator()(T const &lhs, T const &rhs) const { + return (lhs >= rhs); + } +}; + +/// Greater +template +struct greater { + CUTLASS_HOST_DEVICE + bool operator()(T const &lhs, T const &rhs) const { + return (lhs > rhs); + } +}; + +/// Less equal +template +struct less_equal { + CUTLASS_HOST_DEVICE + bool operator()(T const &lhs, T const &rhs) const { + return (lhs <= rhs); + } +}; + +/// Less +template +struct less { + CUTLASS_HOST_DEVICE + bool operator()(T const &lhs, T const &rhs) const { + return (lhs < rhs); + } +}; + +template +struct maximum { + CUTLASS_HOST_DEVICE + T operator()(T const &lhs, T const &rhs) const { + if constexpr (PropagateNaN && cutlass::platform::is_floating_point::value) { + using CUTLASS_CMATH_NAMESPACE :: isnan; + + // Call isnan unqualified, so argument-dependent lookup (ADL) + // will find overloads such as cutlass::isnan(half_t). + // Calling ::isnan or std::isnan directly would force + // implicit conversions to float of custom number types + // in the cutlass namespace (e.g., cutlass::half_t). + return lhs > rhs || isnan(lhs) ? 
lhs : rhs; + } + else { + return (lhs < rhs ? rhs : lhs); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// This is a subclass and not an alias +// in order to work around a known Clang issue, +// where a template template parameter with one template parameter +// does not match classes that take multiple template parameters +// but have defaults for all but the first. +template +struct maximum_with_default_nan_propagation : public maximum +{}; + +template <> +struct maximum { + CUTLASS_HOST_DEVICE + float operator()(float const &lhs, float const &rhs) const { + return fmaxf(lhs, rhs); + } +}; + +template <> +struct maximum { + CUTLASS_HOST_DEVICE + float operator()(float lhs, float rhs) const { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + float res; + asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); + return res; +#else + using CUTLASS_CMATH_NAMESPACE :: isnan; + + return lhs > rhs || isnan(lhs) ? lhs : rhs; +#endif + } +}; + +// This is a subclass and not an alias +// in order to work around a known Clang issue, +// where a template template parameter with one template parameter +// does not match classes that take multiple template parameters +// but have defaults for all but the first. +template +struct maximum_with_nan_propagation : maximum +{}; + +// This alias exists for backwards compatibility only. +// Please use the correctly spelled class template above. +template +using maximum_with_nan_propogation = maximum_with_nan_propagation; + +template +struct minimum { + CUTLASS_HOST_DEVICE + T operator()(T const &lhs, T const &rhs) const { + if constexpr (PropagateNaN && cutlass::platform::is_floating_point::value) { + using CUTLASS_CMATH_NAMESPACE :: isnan; + + return lhs < rhs || isnan(lhs) ? lhs : rhs; + } + else { + return (rhs < lhs ? 
rhs : lhs); + } + } +}; + +template <> +struct minimum { + CUTLASS_HOST_DEVICE + float operator()(float const &lhs, float const &rhs) const { + return fminf(lhs, rhs); + } +}; + +template <> +struct minimum { + CUTLASS_HOST_DEVICE + float operator()(float lhs, float rhs) const { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + float res; + asm volatile("min.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); + return res; +#else + // No need for ADL; call std::isnan(float) on host and ::isnan(float) on device. + return lhs < rhs || (CUTLASS_CMATH_NAMESPACE :: isnan(lhs)) ? lhs : rhs; +#endif + } +}; + +template +struct minimum_with_nan_propagation : minimum +{}; + +template +struct maximum_absolute_value { + CUTLASS_HOST_DEVICE + float operator()(T const &lhs, T const &rhs) const { + absolute_value_op abs_op; + maximum max_op; + + return max_op(abs_op(lhs), abs_op(rhs)); + } +}; + +// assumes the left operand is already an absolute value +template +struct maximum_absolute_value_reduction { + CUTLASS_HOST_DEVICE + float operator()(T const &lhs, T const &rhs) const { + absolute_value_op abs_op; + maximum max_op; + + return max_op(lhs, abs_op(rhs)); + } +}; + +/// Fused multiply-add +template +struct multiply_add { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + return C(a) * C(b) + c; + } +}; + +template +struct square_and_plus { + CUTLASS_HOST_DEVICE + T operator()(T lhs, T const &rhs) const { + multiply_add multiply_add_op; + return multiply_add_op(rhs, rhs, lhs); + } +}; + +// Fused multiply-add that takes exactly one template parameter. +// This is useful for working around a known Clang issue, +// where a template template parameter with one template parameter +// does not match classes that take multiple template parameters +// but have defaults for all but the first. 
+template +struct homogeneous_multiply_add : public multiply_add +{}; + +/// Fused multiply-add +template +struct multiply_add_relu0 { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + maximum mx; + return mx(C(a) * C(b) + c, C(0)); + } +}; + +/// Guarded-multiply-add +template +struct guarded_multiply_add { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + using CUTLASS_CMATH_NAMESPACE :: isnan; + + if (isnan(a) || isnan(b)) { + return C(0); + } + return C(a) * C(b) + c; + } +}; + +/// Guarded-multiply-add +template <> +struct guarded_multiply_add { + CUTLASS_HOST_DEVICE + half_t operator()(half_t const &a, half_t const &b, half_t const &c) const { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + half_t result; + asm ("fma.rn.oob.f16 %0, %1, %2, %3;\n" + : "=h"(*reinterpret_cast(&result)) + : "h"(*reinterpret_cast(&a)), "h"(*reinterpret_cast(&b)), "h"(*reinterpret_cast(&c))); + return result; +#else + // Namespace-qualifying isnan as cutlass::isnan saves the compiler + // the trouble of argument-dependent lookup. Calling std::isnan or + // ::isnan here would result in unwanted implicit conversion to float. 
+ if (cutlass::isnan(a) || cutlass::isnan(b)) { + return half_t(0); + } + return a * b + c; +#endif + } +}; + +/// Guarded-multiply-add-relu0 +template +struct guarded_multiply_add_relu0 { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + using CUTLASS_CMATH_NAMESPACE :: isnan; + + if (isnan(a) || isnan(b)) { + return C(0); + } + maximum mx; + return mx(C(a) * C(b) + c, C(0)); + } +}; + +template <> +struct guarded_multiply_add_relu0 { + CUTLASS_HOST_DEVICE + half_t operator()(half_t const &a, half_t const &b, half_t const &c) const { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + half_t result; + asm ("fma.rn.oob.relu.f16 %0, %1, %2, %3;\n" + : "=h"(*reinterpret_cast(&result)) + : "h"(*reinterpret_cast(&a)), "h"(*reinterpret_cast(&b)), "h"(*reinterpret_cast(&c))); + return result; +#else + if (cutlass::isnan(a) || cutlass::isnan(b)) { + return half_t(0); + } + maximum mx; + return mx(a * b + c, half_t(0)); +#endif + } +}; + + +/// Fused and-popc-add +template +struct and_popc_add { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + A and_result = a & b; + int32_t popc_result = detail::popcount(and_result); + return C(popc_result) + c; + } +}; + +/// Fused and-add +template +struct and_add { + CUTLASS_HOST_DEVICE + T operator()(T const &a, T const &b, T const &c) const { + return ((a & b) + c); + } +}; + + + +/// Fused xor-popc-add +template +struct xor_popc_add { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + A xor_result = a ^ b; + int32_t popc_result = detail::popcount(xor_result); + return C(popc_result) + c; + } +}; + +/// Fused xor-add +template +struct xor_add { + CUTLASS_HOST_DEVICE + T operator()(T const &a, T const &b, T const &c) const { + return ((a ^ b) + c); + } +}; + + +/// Fused or-popc-add +template +struct or_popc_add { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + A or_result = a | b; + int32_t 
popc_result = detail::popcount(or_result); + return C(popc_result) + c; + } +}; + + +/// Fused or-add +template +struct or_add { + CUTLASS_HOST_DEVICE + T operator()(T const &a, T const &b, T const &c) const { + return ((a | b) + c); + } +}; + +namespace detail { + +// Whether namespace-unqualified conj(t) for t of type T is +// well-formed. This says whether the compiler can find +// namespace-unqualified conj(T) via argument-dependent lookup. +// If so, then CUTLASS assumes that conj(t) returns +// the complex conjugate of t. +template +struct has_unqualified_conj : cutlass::platform::false_type +{}; + +template +struct has_unqualified_conj< + T, + decltype(static_cast(conj(cutlass::platform::declval())), void()) + > : cutlass::platform::true_type +{}; + +template +constexpr bool has_unqualified_conj_v = has_unqualified_conj::value; + +} // namespace detail + +// forward declaration (needed for conjugate below) +template +CUTLASS_HOST_DEVICE T conj(T const& z); + +namespace detail { + +// Whether cutlass::conj(t) for t of type T is well-formed. +// If so, then CUTLASS assumes that cutlass::conj(t) +// returns the complex conjugate of t. +template +struct has_cutlass_conj : cutlass::platform::false_type +{}; + +template +struct has_cutlass_conj< + T, + decltype(cutlass::conj(cutlass::platform::declval()), void()) + > : cutlass::platform::true_type +{}; + +template +constexpr bool has_cutlass_conj_v = has_cutlass_conj::value; + +} // namespace detail + +// Return the complex conjugate of the input. +// +// If the struct hasn't already been specialized for type T, then +// +// 1. for arithmetic types, return z; +// +// 2. for types where either (namespace-unqualified) conj(z) or +// cutlass::conj(z) is well formed, declare "using cutlass::conj;" +// and return conj(z); and +// +// 3. for everything else, return z. +// +// Regarding (1), the C++ Standard Library makes std::conj always +// return std::complex, even for (noncomplex) arithmetic types. 
+// cutlass::conj(T t) needs to return type T. This follows the +// convention of linear algebra software like the BLAS, where +// "conjugate transpose" means the same thing as "transpose" for a +// matrix of noncomplex numbers. +// +// Case (2) covers std::complex, cuda::std::complex, and non-Standard +// (including user-defined) complex number types (for which "conj(z)" +// is findable via argument-dependent lookup). cutlass::conj has a +// totally generic overload, but a more type-specific overload in any +// namespace will take precedence. +// +// Case (3) covers non-Standard non-complex number types. +// +// Users should not generally need to specialize this struct for their +// own custom complex or noncomplex types. The idiomatic way to +// identify a type T as "complex" is to make namespace-unqualified +// calls to conj(T) findable via argument-dependent lookup. +template +struct conjugate { + CUTLASS_HOST_DEVICE + T operator()(T const& z) const { + if constexpr (cutlass::platform::is_arithmetic_v) { + return z; + } + else if constexpr (detail::has_unqualified_conj_v || detail::has_cutlass_conj_v) { + using cutlass::conj; + return conj(z); + } + else { + return z; + } + } +}; + +template +struct first { + CUTLASS_HOST_DEVICE + T operator()(T const & first, T const &...) const { + return first; + } + CUTLASS_HOST_DEVICE + T operator()(T const & first) const { + return first; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct logical_and { + CUTLASS_HOST_DEVICE + T operator()(T const &a, T const &b) const { + return ((static_cast(a) && static_cast(b)) ? T(1) : T()); + } +}; + +template +struct logical_or { + CUTLASS_HOST_DEVICE + T operator()(T const &a, T const &b) const { + return ((static_cast(a) || static_cast(b)) ? 
T(1) : T()); + } +}; + +template +struct logical_not { + CUTLASS_HOST_DEVICE + T operator()(T const &a) const { + return T(!(a)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct bit_and { + CUTLASS_HOST_DEVICE + T operator()(T const &a, T const &b) const { + return a & b; + } +}; + +template +struct bit_or { + CUTLASS_HOST_DEVICE + T operator()(T const &a, T const &b) const { + return a | b; + } +}; + +template +struct bit_not { + CUTLASS_HOST_DEVICE + T operator()(T const &a) const { + return ~a; + } +}; + +template +struct bit_xor { + CUTLASS_HOST_DEVICE + T operator()(T const &a, T const &b) const { + return a ^ b; + } +}; + +////////////////////////////////////////////////////////////////////////////////////////////////// +/// Atomic reductions + +template +struct atomic_add +{ + CUTLASS_DEVICE + void operator()(T *ptr, const T &data) + { +#if defined(__CUDA_ARCH__) + atomicAdd(ptr, data); +#else + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(data); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +template<> +struct atomic_add +{ + CUTLASS_DEVICE + void operator()(double *ptr, const double &data) + { +#if !defined(__CUDA_ARCH__) + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(data); + CUTLASS_NOT_IMPLEMENTED(); +#elif (__CUDA_ARCH__ >= 600) + atomicAdd(ptr, data); +#else + // Use CAS loop + unsigned long long int* ptr_int = reinterpret_cast(ptr); + unsigned long long int old_int = *ptr_int; + unsigned long long int assumed_int; + + do { + double update = data + __longlong_as_double(old_int); + assumed_int = old_int; + old_int = atomicCAS(ptr_int, assumed_int, __double_as_longlong(update)); + } while (assumed_int != old_int); +#endif // (__CUDA_ARCH__ >= 600) + } +}; + +template<> +struct atomic_add +{ + CUTLASS_DEVICE + void operator()(half2 *ptr, const half2 &data) + { +#if !defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)) + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(data); 
+ CUTLASS_NOT_IMPLEMENTED(); +#else + // Vector-2 atomic reduction requires .target sm_60 or higher + uint32_t word = reinterpret_cast(data); + asm volatile ("red.gpu.global.add.noftz.f16x2 [%0], %1;\n" : : "l"(ptr), "r"(word)); +#endif // (__CUDA_ARCH__ >= 600) + } +}; + +template +using red [[deprecated("use atomic_add instead")]] = atomic_add; + +template +struct atomic_maximum { + CUTLASS_DEVICE + T operator()(T *ptr, T value) const { +#if defined(__CUDA_ARCH__) + return atomicMax(ptr, value); +#else + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(value); + CUTLASS_NOT_IMPLEMENTED(); + return 0; +#endif + } +}; + +template <> +struct atomic_maximum { + CUTLASS_DEVICE + float operator()(float *ptr, float value) const { +#if defined(__CUDA_ARCH__) + // In device code, make sure that we do NOT try to use + // std::signbit, as that won't work if building with NVRTC. + // Instead, prefix "::" to call signbit from the global namespace, + // which CUDA guarantees to work in device code without including + // any headers. + // + return ! ::signbit(value) ? 
+ __int_as_float(atomicMax((int*)ptr, __float_as_int(value))) : + __uint_as_float(atomicMin((unsigned int*)ptr, __float_as_uint(value))); +#else + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(value); + CUTLASS_NOT_IMPLEMENTED(); + return 0; +#endif + } +}; + +// is_atomic +template +struct is_atomic : platform::false_type {}; +template +struct is_atomic> : platform::true_type {}; +template +struct is_atomic> : platform::true_type {}; + + +////////////////////////////////////////////////////////////////////////////////////////////////// +/// Parallel Synchronization and Communication Instructions +template +struct redux_abs_max_nan_propagation_sync_warp; + +template <> +struct redux_abs_max_nan_propagation_sync_warp { + CUTLASS_DEVICE + float operator()(float const &lhs) const { +#if defined(CUTLASS_ARCH_CREDUX_ENABLED) + float result; + asm volatile("redux.sync.max.abs.NaN.f32 %0, %1, 0xffffffff;\n" : "=f"(result) : "f"(lhs)); + return result; +#elif defined(__CUDA_ARCH__) + cutlass::maximum max_op; + int shuffle_width = 32; + float abs_max = cutlass::absolute_value_op{}(lhs); + CUTLASS_PRAGMA_UNROLL + for(int offset = shuffle_width / 2; offset > 0; offset /= 2) { + float value = __shfl_down_sync(0xffffffff, abs_max, offset, shuffle_width); + abs_max = max_op(abs_max,value); + } + // Broadcast the maximum to all threads participating in the reduction. + abs_max = __shfl_sync(0xffffffff, abs_max, 0, shuffle_width); + return abs_max; +#else + CUTLASS_UNUSED(lhs); + CUTLASS_NOT_IMPLEMENTED(); + return 0; +#endif + } +}; + +template +struct redux_abs_max_nan_propagation_sync_warp_t0t15_t16t31; + +template <> +struct redux_abs_max_nan_propagation_sync_warp_t0t15_t16t31{ + CUTLASS_DEVICE + float operator()(float const &max) const { +#if defined(CUTLASS_ARCH_CREDUX_ENABLED) + int half_warp_idx = threadIdx.x / (NumThreadsPerWarp / 2); + bool first_half_threads = (half_warp_idx % 2) == 0; + float value0 = first_half_threads ? 
max : 0; + float v0 = cutlass::redux_abs_max_nan_propagation_sync_warp{}(value0); + + float value1 = !first_half_threads ? max : 0; + float v1 = cutlass::redux_abs_max_nan_propagation_sync_warp{}(value1); + return first_half_threads ? v0: v1; + +#elif defined(__CUDA_ARCH__) + float abs_max = cutlass::absolute_value_op{}(max); + cutlass::maximum max_op; + constexpr int shuffle_width = 16; + CUTLASS_PRAGMA_UNROLL + for(int offset = shuffle_width/2; offset > 0; offset /= 2) { + float value = __shfl_down_sync(0xffffffff, abs_max, offset, shuffle_width); + abs_max = max_op(abs_max,value); + } + // Broadcast the maximum to all threads participating in the reduction. + abs_max = __shfl_sync(0xffffffff, abs_max, 0, shuffle_width); + return abs_max; +#else + CUTLASS_UNUSED(max); + CUTLASS_NOT_IMPLEMENTED(); + return 0; +#endif + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for nvcuda::wmma::fragment +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +template +struct plus> +{ + using Fragment = nvcuda::wmma::fragment; + using ElementType = typename Fragment::element_type; + + CUTLASS_HOST_DEVICE + Fragment operator()(Fragment const &lhs, Fragment const &rhs) const + { + Fragment result; + plus scalar_op; + + ElementType *result_elts = reinterpret_cast(&result); + const ElementType *lhs_elts = reinterpret_cast(&lhs); + const ElementType *rhs_elts = reinterpret_cast(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Fragment::num_elements; i++) { + result_elts[i] = scalar_op(lhs_elts[i], rhs_elts[i]); + } + + return result; + } +}; + +#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) + + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_builder.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_builder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..83a65059af41edf89fd4b977e6973a3e6d612ea5 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_builder.hpp @@ -0,0 +1,62 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_mma_decl.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/gemm/collective/collective_builder_decl.hpp" +#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" +#include "cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl" +#if !defined(__CUDACC_RTC__) +#include "cutlass/gemm/collective/builders/sm100_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_simt_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl" +#include 
"cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm120_mma_builder.inl" +#include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl" +#include "cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl" +#include "cutlass/gemm/collective/builders/sm120_blockscaled_sparse_mma_builder.inl" +#include "cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl" +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_builder_decl.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_builder_decl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aae73348b5a205494a7f7c2ee0407bd67a5b42a3 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_builder_decl.hpp @@ -0,0 +1,100 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify stage counts or dispatch to automatic computation of stage count +template +struct StageCount { + static constexpr int value = num_stages; + + StageCount() = default; + explicit StageCount(cute::Int) {} +}; + +template +struct StageCountAutoCarveout { + static constexpr int bytes = carveout_bytes; + + StageCountAutoCarveout() = default; + explicit StageCountAutoCarveout(cute::Int) {} +}; + +namespace detail { + +// Forward Declaration +template +constexpr int +compute_carveout_from_epi(); + +} // namespace detail + +template +struct StageCountAutoCarveoutEpi : StageCountAutoCarveout()> {}; + +using StageCountAuto = StageCountAutoCarveout<0>; + +// Used to automatically let the builder 
pick the kernel schedule. +// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp +struct KernelScheduleAuto final {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class OpClass, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType, + class Enable = void +> +struct CollectiveBuilder { + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_mma.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_mma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e3ae8003794507f9c9d7183c388fcf6074a40eb --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_mma.hpp @@ -0,0 +1,84 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/collective_mma_decl.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/gemm/collective/sm70_mma_twostage.hpp" +#include "cutlass/gemm/collective/sm80_mma_multistage.hpp" +#include "cutlass/gemm/collective/sm80_mma_array_multistage.hpp" +#include "cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp" +#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" +#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" +#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" +#if !defined(__CUDACC_RTC__) +#include "cutlass/gemm/collective/sm100_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp" +#include 
"cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp" +#include "cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp" +#include "cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm120_mma_tma.hpp" +#include "cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp" +#include "cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp" +#include "cutlass/gemm/collective/sm120_sparse_mma_tma.hpp" +#include "cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp" +#include "cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp" +#include "cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp" +#endif // !defined(__CUDACC_RTC__) + + + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_mma_decl.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_mma_decl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a2faa1ff28e0fc52491937fd003396fca1ffe646 --- /dev/null +++ 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/collective_mma_decl.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class TileShape, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class TiledMma, + class GmemTiledCopyA, + class SmemLayoutAtomA, + class SmemCopyAtomA, + class TransformA, + class GmemTiledCopyB, + class SmemLayoutAtomB, + class SmemCopyAtomB, + class TransformB +> +struct CollectiveMma { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/fp8_accumulation.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/fp8_accumulation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6ff3a94478fa1916b77938d2ca77178ef7d6bc43 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/fp8_accumulation.hpp @@ -0,0 +1,279 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/algorithm/clear.hpp" +#include "cute/tensor.hpp" + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////FP8 Accumulation/////////////////////////// +////////////////////////////////////////////////////////////////////////////// +/// This class provides API to promote (add) or scale (multiply_add) the results +/// from the tensor core accumulators to the main accumulators when the number +/// of MMAs reaches the max number of MMA interval specified by user, after that +/// the tensor core accumulators are zeroed. 
+////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +template < + class EngineAccum, + class LayoutAccum> +struct GmmaFP8Accumulation { + using TensorAccum = cute::Tensor; + using ElementAccumulator = typename EngineAccum::value_type; + + static_assert(is_static::value, "Accumulator Layout should be static"); + static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); + +private: + TensorAccum& accum_; + TensorAccum accum_temp_; + + uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. + uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop + uint32_t mma_count_; // current executed MMAs + uint32_t reset_accum_flag_; // accum needs to be zeroed or not. + + // promote or `add` the partial accumulators to main accumulator (FADD). + CUTLASS_DEVICE + void promote_core() { + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i); + } + } + + // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). 
+ CUTLASS_DEVICE + void scale_core(ElementAccumulator const &scale) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i) * scale; + } + } + + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_core(const cute::Tensor &scale) { + using TensorScale = cute::Tensor; + + static_assert(is_static::value, "Scale Layout should be static"); + static_assert(is_rmem::value , "Scale tensor must be rmem resident."); + + static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i) * scale(i); + } + } + + template < + class EngineScaleA, + class LayoutScaleA, + class EngineScaleB, + class LayoutScaleB> + CUTLASS_DEVICE + void scale_core(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + using TensorScaleA = cute::Tensor; + using TensorScaleB = cute::Tensor; + + static_assert(is_static::value, "ScaleA Layout should be static"); + static_assert(is_static::value, "ScaleB Layout should be static"); + static_assert(is_rmem::value, "ScaleA tensor must be rmem resident."); + static_assert(is_rmem::value, "ScaleB tensor must be rmem resident."); + + static_assert(LayoutAccum{}.shape() == LayoutScaleA{}.shape(), "Accumulator and scaleA must have same shape."); + static_assert(LayoutAccum{}.shape() == LayoutScaleB{}.shape(), "Accumulator and scaleB must have same shape."); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i) * scaleA(i) * scaleB(i); + } + } + +public: + CUTLASS_DEVICE + GmmaFP8Accumulation( + TensorAccum &accum, + uint32_t accum_promotion_interval, + uint32_t mma_count_per_mainloop_iteration) + : accum_(accum), + accum_promotion_interval_(accum_promotion_interval), + mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), + mma_count_(0), + reset_accum_flag_(0) + { + accum_temp_ = 
cute::make_fragment_like(accum); + } + + // + // Methods (Common) + // + + CUTLASS_DEVICE + TensorAccum& operator()() { + return accum_temp_; + } + + /// prepare the MMA accumulators when initialization or zeroing is required. + CUTLASS_DEVICE + bool prepare_if_needed() { + return reset_accum_flag_; + } + + // + // Methods (for FADD version) + // + + /// promote (add) the results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_if_needed() { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + promote_core(); + mma_count_ = 0; + } + } + + /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_residue_if_needed() { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + promote_core(); + } + } + + // + // Methods (for FFMA version) + // + + /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. 
+ CUTLASS_DEVICE + void scale_if_needed(ElementAccumulator const &scale) { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + scale_core(scale); + mma_count_ = 0; + } + } + + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_if_needed(const cute::Tensor &scale) { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + scale_core(scale); + mma_count_ = 0; + } + } + + template < + class EngineScaleA, + class LayoutScaleA, + class EngineScaleB, + class LayoutScaleB> + CUTLASS_DEVICE + void scale_if_needed(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + scale_core(scaleA, scaleB); + mma_count_ = 0; + } + } + + /// scale (multiply_add) the results from the MMA accumulators to main accumulator without checking the counter. + CUTLASS_DEVICE + void scale(ElementAccumulator const &scale) { + scale_core(scale); + } + + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale(const cute::Tensor &scale) { + scale_core(scale); + } + + template < + class EngineScaleA, + class LayoutScaleA, + class EngineScaleB, + class LayoutScaleB> + CUTLASS_DEVICE + void scale(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + scale_core(scaleA, scaleB); + } + + /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. 
+ CUTLASS_DEVICE + void scale_residue_if_needed(ElementAccumulator const &scale) { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + scale_core(scale); + } + } + + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale_residue_if_needed(const cute::Tensor &scale) { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + scale_core(scale); + } + } + + template < + class EngineScaleA, + class LayoutScaleA, + class EngineScaleB, + class LayoutScaleB> + CUTLASS_DEVICE + void scale_residue_if_needed(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + scale_core(scaleA, scaleB); + } + } +}; + +} // namespace cutlass::gemm::collective diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2665ef1c2e894f7f937700f5d18902c122147bfb --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp @@ -0,0 +1,1322 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + 
StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. + using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or + shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 64/128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static constexpr int IsCtaN64 = shape<1>(CtaShape_MNK{}) == 64; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. 
+ // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + using InternalStrideA = cute::remove_pointer_t; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using InternalLayoutSFA = 
cute::remove_pointer_t; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + using InternalLayoutSFB = cute::remove_pointer_t; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB 
must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = 
detail::is_sm100_mma_f8f6f4(); + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + using TmaInternalElementA = cute::conditional_t; + using TmaInternalElementB = cute::conditional_t; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_SFA; + cute::TmaDescriptor smem_tensormap_SFB; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
+ using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes; + + template + struct TmemStorage { + AccTensor accumulators; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const** ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const** ptr_B{nullptr}; + StrideB dB{}; + ElementSF const** ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const** ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), 
repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), InternalLayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), InternalLayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const** ptr_A; + StrideA dA; + ArrayElementB const** ptr_B; + StrideB dB; + ElementSF const** ptr_SFA; + LayoutSFA layout_SFA; + ElementSF const** ptr_SFB; + LayoutSFB layout_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , layout_SFA_(params.layout_SFA) + , layout_SFB_(params.layout_SFB) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + 
cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? ¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_M = int32_t(size<0>(TileShape{})); + auto init_N = int32_t(size<1>(TileShape{})); + auto init_K = int32_t(size<2>(TileShape{})); + auto init_L = 1; + + // Tensor pointers will be fixed before the first access + TmaInternalElementA const* ptr_A_first_batch = nullptr; + TmaInternalElementB const* ptr_B_first_batch = nullptr; + + InternalStrideA stride_a; + InternalStrideB stride_b; + InternalLayoutSFA layout_SFA; + InternalLayoutSFB layout_SFB; + + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
+ auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + layout_SFA = args.layout_SFA; + layout_SFB = args.layout_SFB; + } + + // Batches/Groups are managed by using appropriate pointers to input matrices. + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + // Tensor pointers will be fixed before the first access + ElementSF const* ptr_SFA_first_batch = nullptr; + ElementSF const* ptr_SFB_first_batch = nullptr; + + Tensor tensor_sfa = make_tensor(ptr_SFA_first_batch, layout_SFA); + Tensor tensor_sfb = make_tensor(ptr_SFB_first_batch, layout_SFB); + + // Cluster layout for TMA construction of SFB + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + 
SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk); + + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + reinterpret_cast(args.ptr_SFA), + args.layout_SFA, + reinterpret_cast(args.ptr_SFB), + args.layout_SFB, + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 
4; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + 
slice_accumulator(TmemStorage tmem_storage, int stage) { + return tmem_storage.accumulators(_,_,_,stage); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); + } + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// tAgSFA_mkl - partitioned gmem tensor for SFA + /// tBgSFB_nkl - partitioned gmem tensor for SFB + /// tAsSFA - partitioned tmem tensor for SFA + /// tAsSFB - partitioned tmem tensor for SFB + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Represent the full tensor of Scale factors + InternalLayoutSFA layout_SFA{}; + InternalLayoutSFB layout_SFB{}; + if constexpr (IsGroupedGemmKernel) { + layout_SFA = params.layout_SFA[init_group]; + 
layout_SFB = params.layout_SFB[init_group]; + } + else { + layout_SFA = params.layout_SFA; + layout_SFB = params.layout_SFB; + } + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else if constexpr (IsCtaN64) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + } + }(); + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // 
(MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMMA_SF{}.get_slice(blockIdx.x % size(typename TiledMMA_SF::AtomThrID{})); + Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-Cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), 
make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + /// Set up the data needed by this collective for mma compute. 
+ template + CUTLASS_DEVICE auto + mma_init( + TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. 
+ auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class TensorMapA, class TensorMapB, + class TensorMapSFA, class TensorMapSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = 
mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + // Note: We don't synchronize the sf_pipeline for "Buffer_Empty". We use mainloop pipeline + // to do the synchronization at once. + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + copy(observed_tma_load_sfa_->with(get<2>(input_tensormaps), *tma_barrier, mcast_mask_sfa), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(get<3>(input_tensormaps), *tma_barrier, mcast_mask_sfb), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + 
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class CtaTileCoord, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto [tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_s2t, + thr_tCtSFA_s2t, tiled_copy_s2t_SFB, + thr_tCsSFB_s2t, thr_tCtSFB_s2t] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (size<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else if constexpr (IsCtaN64) { + // Move in increments of 64 columns of SFB + auto tCtSFB_tmp = tCtSFB; + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + (size<1>(cta_tile_coord) % 2) * 2; + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + 
if constexpr (IsOverlappingAccum) { + // first iteration manual unroll for tmem overlap kernel + if (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + } + else { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + 
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + cute::TmaDescriptor* tma_desc_sfa = &gmem_tensormap[sm_idx + 2 * sm_count]; + cute::TmaDescriptor* tma_desc_sfb = &gmem_tensormap[sm_idx + 3 * sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = 
make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + Tensor pSFA_tensormap = make_tensor(observed_tma_load_sfa_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFA), Int<1>{}, Int<1>{}); + Tensor pSFB_tensormap = make_tensor(observed_tma_load_sfb_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFB), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + + copy(recast(pSFA_tensormap), recast(sSFA_tensormap)); + copy(recast(pSFB_tensormap), recast(sSFB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_sfa, tma_desc_sfb); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + mainloop_params.ptr_SFA[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + mainloop_params.ptr_SFB[next_batch]); + } + + // Replace dim and strides for the global 
tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_SFA = {1,1,1,1,1}; + cute::array prob_stride_SFA = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + cute::array prob_shape_SFB = {1,1,1,1,1}; + cute::array prob_stride_SFB = {0,0,0,0,0}; + + TmaInternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + ElementSF const* ptr_SF = nullptr; + Tensor tensor_sfa = make_tensor(ptr_SF, mainloop_params.layout_SFA[next_group]); + + TmaInternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + Tensor tensor_sfb = make_tensor(ptr_SF, mainloop_params.layout_SFB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfa_, tensor_sfa, + prob_shape_SFA, prob_stride_SFA); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfb_, tensor_sfb, + prob_shape_SFB, prob_stride_SFB); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_SFA) { + stride = (stride * 
sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_SFB) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + prob_shape_SFA, + prob_stride_SFA); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + prob_shape_SFB, + prob_stride_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // 
Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + + tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_SFA); + tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<2>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<3>(input_tensormaps)); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + LayoutSFA layout_SFA_; + LayoutSFB layout_SFB_; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..344de4d33ba04dbf2d147a035614c8445fca8d25 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -0,0 +1,1043 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/gemm/collective/collective_mma_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class 
ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> { + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + static_assert(size(typename TiledMma::AtomThrID{}) == 1); + + using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster + using TileShape = TileShape_; + using TiledMma_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + static_assert(!IsOverlappingAccum, "TMA+CPASYNC kernel currently only supports non-overlapping accum."); + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = 
decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or + shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 64/128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static constexpr int IsCtaN64 = shape<1>(CtaShape_MNK{}) == 64; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. 
+ // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + + using ElementA = remove_cvref_t(ElementPairA{}))>; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = remove_cvref_t(StridePairA{}))>; + using ElementB = remove_cvref_t(ElementPairB{}))>; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = remove_cvref_t(StridePairB{}))>; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = 
remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipelineTMA = cutlass::PipelineTmaUmmaAsync; + using MainloopPipelineTMAState = typename MainloopPipelineTMA::PipelineState; + + using MainloopPipelineCpAsync = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineCpAsyncState = typename MainloopPipelineCpAsync::PipelineState; + + // static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreadsCpAsync = size(GmemTiledCopyB{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + + using TmaInternalElementA = cute::conditional_t; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = 
typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineStorageTMA = typename MainloopPipelineTMA::SharedStorage; + using PipelineStorageCpAsync = typename MainloopPipelineCpAsync::SharedStorage; + + struct PipelineStorage : cute::aligned_struct<16, _0> { + alignas(16) PipelineStorageTMA tma; + alignas(16) PipelineStorageCpAsync cpasync; + } pipelines; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + static constexpr uint32_t ATmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = ATmaTransactionBytes + SFTransactionBytes; + + template + struct TmemStorage { + AccTensor accumulators; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using 
ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMma_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) + : layout_SFA_(params.layout_SFA) + , layout_SFB_(params.layout_SFB) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto 
[M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + auto cluster_layout_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMma_SF{}, + cluster_layout_sfb_vmnk); + + return { + tma_load_a, + tma_load_sfa, + tma_load_sfb, + args.ptr_B, + args.dB, + args.layout_SFA, + args.layout_SFB, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + // static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN 
IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + + // Check for SFA SFB layout requirement + const auto layout_sfa_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + const auto layout_sfb_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + implementable = implementable && (layout_sfa_ref == args.layout_SFA); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFA mismatch, layout_SFA needs to be K-major\n"); + } + + implementable = implementable && (layout_sfb_ref == args.layout_SFB); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFB mismatch, layout_SFB needs to be K-major\n"); + } + + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfa_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfb_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // 
((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); + } + + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init_tma( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA_)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else if constexpr (IsCtaN64) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + } + }(); + + Tensor gSFA_mkl = 
local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + + ThrMMA cta_mma = TiledMma{}.get_slice(0); + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMma_SF{}.get_slice(blockIdx.x % size(typename TiledMma_SF::AtomThrID{})); + Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0); + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(0); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + 
get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + return cute::make_tuple( + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tAsA, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB // for input scale factor tensor values + ); + } + + template + CUTLASS_DEVICE auto + load_init_cpasync( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TileScheduler const& scheduler, + typename TileScheduler::WorkTileInfo const& work_tile_info) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // convert to subptr iterator if necessary + auto ptr_B = recast_ptr(params.ptr_B); + // Represent the full tensors + Tensor mB_nkl = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), params.dB); //(n,k,l) + // Partition for cpasync + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), LoadSmemLayoutB{}); + + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = threadIdx.x % NumLoadThreadsCpAsync; + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. 
+ template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + TmemStorage tmem_storage, + // [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + // + // Scale Factor + // + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. 
+ auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, + tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t + + // debug + // , sA, sB, tCsSFA, tCsSFB + ); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + // class KTileCount, + // class GTensorPartitionedA, + // class STensorA, + class TileCoordMNKL, + class KTileIterator, + class... TLoadParams // see load_init_tma + > + CUTLASS_DEVICE auto + load_tma( + MainloopPipelineTMA mainloop_pipeline, + MainloopPipelineTMAState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + // Unpack from load_inputs + // KTileCount k_tiles = get<0>(load_inputs); + // GTensorPartitionedA tAgA_mkl = get<1>(load_inputs); + // STensorA tAsA = get<2>(load_inputs); + + auto [k_tiles, + tAgA_mkl, tAsA, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipelineTMA::ProducerBarrierType; + 
BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_sfa_->with(*tma_barrier), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + template < + // class GTensorB, + // class CTensorB, + // class STensorB, + // class ProblemShape_MNKL, + // class TiledCopyB, + // class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator, + class ProblemShape_MNKL, + class... TParams + > + CUTLASS_DEVICE auto + load_cpasync( + Params const& params, + MainloopPipelineCpAsync mainloop_pipeline, + MainloopPipelineCpAsyncState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + ProblemShape_MNKL effective_shape + ) { + + // Unpack from load_inputs + // GTensorB tBgB_nkl = get<0>(load_inputs); + // CTensorB cgB_nk = get<1>(load_inputs); + // STensorB sB = get<2>(load_inputs); + // ProblemShape_MNKL problem_shape_MNKL = get<3>(load_inputs); + // TiledCopyB gmem_to_smem_b_tiled_copy = get<4>(load_inputs); + // ThreadCopyB thr_copy_b = get<5>(load_inputs); + + // auto [M,N,K,L] = problem_shape_MNKL; + + auto [ + tBgB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b] = load_inputs; + + auto [M,N,K,L] = effective_shape; + + // Slice out the work coord from partitioned tensors + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + // 
Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative + + Tensor gB = gB_in; + Tensor cB = cgB_nk_in; + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for n + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + Tensor tBcB_nk = thr_copy_b.partition_S(cgB_nk_in); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tBgB + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // we will process the last tile after the mainloop + if (k_residue != 0) { + --k_tile_count; + } + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + --k_tile_count; + ++k_tile_iter; + ++mainloop_pipe_producer_state; + } + + // last tile with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 
0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgB(_,_,k,*k_tile_iter), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + CUTLASS_DEVICE void + load_tail_cpasync(MainloopPipelineCpAsync mainloop_pipeline, MainloopPipelineCpAsyncState mainloop_pipe_producer_state) { + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class CtaTileCoord, + class... 
TMmaParams + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_s2t, + thr_tCtSFA_s2t, tiled_copy_s2t_SFB, + thr_tCsSFB_s2t, thr_tCtSFB_s2t + + // debug + // , sA, sB, tCsSFA, tCsSFB + ] = mma_inputs; + + auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (size<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else if constexpr (IsCtaN64) { + // Move in increments of 64 columns of SFB + auto tCtSFB_tmp = tCtSFB; + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + (size<1>(cta_tile_coord) % 2) * 2; + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state); + + int read_stage_tma = 
mainloop_pipe_tma_consumer_state.index(); + int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index(); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage_tma), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage_tma), thr_tCtSFB_s2t); + } + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage_tma), + tCrB(_,_,k_block,read_stage_cpasync), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + --k_tile_count; + ++mainloop_pipe_tma_consumer_state; + ++mainloop_pipe_cpasync_consumer_state; + } + + return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + LayoutSFA layout_SFA_; + LayoutSFB layout_SFB_; + + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + // ClusterShape cluster_shape_; + // uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..79a97bed9a5b7d886fce70841c439504fda6cadb --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp @@ -0,0 +1,1104 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class 
ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or + shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 64/128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = 
cutlass::detail::Sm1xxBlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static constexpr int IsCtaN64 = shape<1>(CtaShape_MNK{}) == 64; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should 
be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % 
size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. 
+ // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + + using TmaInternalElementA = cute::conditional_t; + using TmaInternalElementB = cute::conditional_t; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineStorage = typename 
MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes; + + template + struct TmemStorage { + AccTensor accumulators; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + + template < + class KTileCount, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB + > + struct LoadParams { + // for scheduler + KTileCount k_tiles; + // for input tensor values + GTensorPartitionedA tAgA_mkl; + GTensorPartitionedB tBgB_nkl; + STensorA tAsA; + STensorB tBsB; + // for scale factor tensor values + GTensorPartitionedSFA tAgSFA_mkl; + GTensorPartitionedSFB tBgSFB_nkl; + STensorSFA tAsSFA; + STensorSFB tBsSFB; + // the TMA multicast masks + uint16_t mcast_mask_a; + uint16_t mcast_mask_b; + uint16_t mcast_mask_sfa; + uint16_t mcast_mask_sfb; + + CUTLASS_DEVICE + LoadParams ( + KTileCount k_tiles_, + GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_, + STensorA tAsA_, STensorB tBsB_, + 
GTensorPartitionedSFA tAgSFA_mkl_, GTensorPartitionedSFB tBgSFB_nkl_, + STensorSFA tAsSFA_, STensorSFB tBsSFB_, + uint16_t mcast_mask_a_, uint16_t mcast_mask_b_, + uint16_t mcast_mask_sfa_, uint16_t mcast_mask_sfb_) + : k_tiles(k_tiles_) + , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_) + , tAsA(tAsA_), tBsB(tBsB_) + , tAgSFA_mkl(tAgSFA_mkl_), tBgSFB_nkl(tBgSFB_nkl_) + , tAsSFA(tAsSFA_), tBsSFB(tBsSFB_) + , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) + , mcast_mask_sfa(mcast_mask_sfa_), mcast_mask_sfb(mcast_mask_sfb_) {} + }; + + template < + class TiledMma, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + struct MmaParams { + TiledMma tiled_mma; + FragmentA tCrA; + FragmentB tCrB; + FragmentSFA tCtSFA; + FragmentSFB tCtSFB; + SFATiledCopy tiled_copy_s2t_SFA; + SmemFrgSFA thr_tCsSFA_s2t; + TmemFrgSFA thr_tCtSFA_s2t; + SFBTiledCopy tiled_copy_s2t_SFB; + SmemFrgSFB thr_tCsSFB_s2t; + TmemFrgSFB thr_tCtSFB_s2t; + + CUTLASS_DEVICE + MmaParams ( + TiledMma tiled_mma_, + FragmentA tCrA_, FragmentB tCrB_, FragmentSFA tCtSFA_, FragmentSFB tCtSFB_, + SFATiledCopy tiled_copy_s2t_SFA_, SmemFrgSFA thr_tCsSFA_s2t_, TmemFrgSFA thr_tCtSFA_s2t_, + SFBTiledCopy tiled_copy_s2t_SFB_, SmemFrgSFB thr_tCsSFB_s2t_, TmemFrgSFB thr_tCtSFB_s2t_) + : tiled_mma(tiled_mma_) + , tCrA(tCrA_), tCrB(tCrB_), tCtSFA(tCtSFA_), tCtSFB(tCtSFB_) + , tiled_copy_s2t_SFA(tiled_copy_s2t_SFA_), thr_tCsSFA_s2t(thr_tCsSFA_s2t_), thr_tCtSFA_s2t(thr_tCtSFA_s2t_) + , tiled_copy_s2t_SFB(tiled_copy_s2t_SFB_), thr_tCsSFB_s2t(thr_tCsSFB_s2t_), thr_tCtSFB_s2t(thr_tCtSFB_s2t_) {} + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB{nullptr}; + LayoutSFB 
layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t 
block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , layout_SFA_(params.layout_SFA) + , layout_SFB_(params.layout_SFB) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? ¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + 
auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + // Cluster layout for TMA construction of SFB + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk); + + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + 
SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + args.layout_SFA, + args.layout_SFB, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + // Check for SFA SFB layout requirement + const auto layout_sfa_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + const auto layout_sfb_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + implementable = implementable && (layout_sfa_ref == args.layout_SFA); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFA mismatch, layout_SFA needs to be K-major\n"); + } + + implementable = implementable && (layout_sfb_ref == args.layout_SFB); + if (!implementable) { + 
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFB mismatch, layout_SFB needs to be K-major\n"); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfa_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfb_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. 
+ Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// tAgSFA_mkl - partitioned gmem tensor for SFA + /// tBgSFB_nkl - partitioned gmem tensor for SFB + /// tAsSFA - partitioned tmem tensor for SFA + /// tAsSFB - partitioned tmem tensor for SFB + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = 
observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA_)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else if constexpr (IsCtaN64) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + } + }(); + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); 
+ + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMMA_SF{}.get_slice(blockIdx.x % size(typename TiledMMA_SF::AtomThrID{})); + Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), 
make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + + return LoadParams{ + size<3>(gA_mkl), // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb}; // multicast masks + } + + /// Set up the data needed by this collective for mma compute. 
+ template + CUTLASS_DEVICE auto + mma_init( + TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. 
+ auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + + return MmaParams{ + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t}; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class LoadParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + LoadParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_k_tiles, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + // Note: We don't synchronize the sf_pipeline for "Buffer_Empty". We use mainloop pipeline + // to do the synchronization at once. 
+ + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + copy(observed_tma_load_sfa_->with(*tma_barrier, mcast_mask_sfa), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier, mcast_mask_sfb), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class MmaParams, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + MmaParams const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem 
resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_s2t, + thr_tCtSFA_s2t, tiled_copy_s2t_SFB, + thr_tCsSFB_s2t, thr_tCtSFB_s2t] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (size<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else if constexpr (IsCtaN64) { + // Move in increments of 64 columns of SFB + auto tCtSFB_tmp = tCtSFB; + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + (size<1>(cta_tile_coord) % 2) * 2; + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + if constexpr (IsOverlappingAccum) { + // first iteration manual unroll for tmem overlap kernel + if (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // 
Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + } + else { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, 
thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + LayoutSFA layout_SFA_; + LayoutSFB layout_SFB_; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bcf88620c589fcd452840e1fa1fea798b23dd5d1 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp @@ -0,0 +1,1321 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/collective/builders/sm1xx_sparse_config.inl" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class LayoutPairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedBlockScaledSparse< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + 
TileShape_, + ElementPairA_, + LayoutPairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedBlockScaledSparse< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static constexpr int IsCtaN64 = shape<1>(CtaShape_MNK{}) == 64; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. 
+ // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + // CtaK needs to be multiplier of SFAtomK + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using SfAtomK = cute::Int(SfAtom{})>; + static_assert( shape<2>(CtaShape_MNK{}) % SfAtomK{} == 0, "CtaK needs to be multiplier of SFAtomK"); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + static_assert(get<0,0>(MmaShapeA_MK{}) == 128 && + (get<2>(MmaShapeA_MK{}) == 2 || get<2>(MmaShapeA_MK{}) == 4), + "This kernel only support MmaShape=128 and 2/4 kphase."); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using LayoutPairA = LayoutPairA_; + using StridePairB = StridePairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A, B, and E matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaRaw = typename ElementAMma::raw_type; + using LayoutA = remove_cvref_t(LayoutPairA{}))>; + static constexpr int ElementAMmaSparsity = ElementAMma::sparsity; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + using ElementEMma = typename TiledMma::ValTypeE; + using ElementE = typename ElementEMma::raw_type; + using LayoutE = remove_cvref_t(LayoutPairA{}))>; + static constexpr int ElementEMmaSparsity = ElementEMma::sparsity; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + using ElementBMma 
= typename TiledMma::ValTypeB; + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(LayoutPairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(is_sparse::value, "ElementAMma is sparse"); + static_assert(!is_sparse::value, "ElementA is not sparse"); + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + // LayoutA is nested in the stride due to the sparsity. 
+ static constexpr bool is_A_mn_major = cute::is_same_v(LayoutA{})), Int>; + + using SparseConfig = cutlass::Sm1xxGemmSparseConfig, + ElementEMma>; + static constexpr int ElementASparsity = 2; // typename SparseConfig::ElementASparsity{}; + + // The offline permutation for the metadata. + using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + + // Metadata pathways + using GmemCopyAtomE = GmemTiledCopyA; + + using MainloopPipeline = cutlass::PipelineTmaSparseUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static constexpr int UtccpReuseCnt = ((size<2>(TileShape{}) / typename SparseConfig::TensorEAtomK{}) == 0) ? + typename SparseConfig::TensorEAtomK{} / size<2>(TileShape{}) : 1; + static_assert(UtccpReuseCnt == 1 || UtccpReuseCnt == 2, "UTCCP reuse count can only be either one or two"); + // (TileM, TileN, TileK) TileK is adjusted according to the reuse. 
+ using TileShapeE = decltype(replace<2>(TileShape{}, cute::lcm(size<2>(TileShape{}), typename SparseConfig::TensorEAtomK{}))); + using MmaShapeE_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShapeE{}), size<2>(TileShapeE{})))); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomE{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomE{})) == 0, "SmemLayoutAtomE must evenly divide the tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t, Step<_1,_2,_3>>{})); + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) that one UTCCP instruction can provide + using SmemLayoutE = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomE{}, + append(MmaShapeE_MK{}, Int{}))); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static_assert(rank(SmemLayoutAtomE{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomE{})) == 0, "SmemLayoutAtomE must evenly divide 
tile shape."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_sparse_f8f6f4(); + + using TmaInternalElementA = cute::sparse_elem>; + using TmaInternalElementB = cute::conditional_t; + + using SmemAllocTypeA = cute::sparse_elem < 8, + uint8_t, + ElementAMmaRaw>>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + // Kernel Input Data Type that consider runtime dtype + using ArrayElementA = cute::conditional_t>, + ElementA>; + using ArrayElementB = cute::conditional_t>, + ElementB>; + + using RuntimeDataTypeA = cute::conditional_t, + void*>; + + using RuntimeDataTypeB = cute::conditional_t, + void*>; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_E; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
+ using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t MetadataTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutE{})) * cute::sizeof_bits_v); + static constexpr uint32_t MainLoadTmaTransactionBytes = SFTransactionBytes + ABTmaTransactionBytes; + + template < + class AccTensor, + class ETensor, class SfaTensor, class SfbTensor + > + struct TmemStorage { + AccTensor accumulators; + ETensor tCtE; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + + template < + class KTileCount, + class GTensorPartitionedA, class GTensorPartitionedB, class GTensorPartitionedE, + class STensorA, class STensorB, class STensorE, + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB + > + struct LoadParams { + // for scheduler + KTileCount k_tiles; + // for input tensor values + GTensorPartitionedA tAgA_mkl; + GTensorPartitionedB tBgB_nkl; + GTensorPartitionedE tEgE_nkl; + STensorA tAsA; + STensorB tBsB; + STensorE tEsE; + GTensorPartitionedSFA tAgSFA_mkl; + GTensorPartitionedSFB tBgSFB_nkl; + STensorSFA tAsSFA; + STensorSFB tBsSFB; + // the TMA multicast masks + uint16_t mcast_mask_a; + uint16_t mcast_mask_b; + uint16_t mcast_mask_e; + uint16_t mcast_mask_sfa; + 
uint16_t mcast_mask_sfb; + + CUTLASS_DEVICE + LoadParams ( + KTileCount k_tiles_, + GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_, GTensorPartitionedE tEgE_nkl_, + STensorA tAsA_, STensorB tBsB_, STensorE tEsE_, + GTensorPartitionedSFA tAgSFA_mkl_, GTensorPartitionedSFB tBgSFB_nkl_, + STensorSFA tAsSFA_, STensorSFB tBsSFB_, + uint16_t mcast_mask_a_, uint16_t mcast_mask_b_, uint16_t mcast_mask_e_, + uint16_t mcast_mask_sfa_, uint16_t mcast_mask_sfb_) + : k_tiles(k_tiles_) + , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_), tEgE_nkl(tEgE_nkl_) + , tAsA(tAsA_), tBsB(tBsB_), tEsE(tEsE_) + , tAgSFA_mkl(tAgSFA_mkl_), tBgSFB_nkl(tBgSFB_nkl_) + , tAsSFA(tAsSFA_), tBsSFB(tBsSFB_) + , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_), mcast_mask_e(mcast_mask_e_) + , mcast_mask_sfa(mcast_mask_sfa_), mcast_mask_sfb(mcast_mask_sfb_) {} + }; + + template < + class TiledMma, + class FragmentA, class FragmentB, + class FragmentE, class ETiledCopy, class SmemFrgE, class TmemFrgE, + class FragmentSFA, class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class FragmentSFB, class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + struct MmaParams { + TiledMma tiled_mma; + // A + FragmentA tCrA; + // B + FragmentB tCrB; + // E + FragmentE tCtE; + ETiledCopy tiled_copy_s2t_E; + SmemFrgE thr_tCsE_s2t; + TmemFrgE thr_tCtE_s2t; + // SFA + FragmentSFA tCtSFA; + SFATiledCopy tiled_copy_s2t_SFA; + SmemFrgSFA thr_tCsSFA_s2t; + TmemFrgSFA thr_tCtSFA_s2t; + // SFB + FragmentSFB tCtSFB; + SFBTiledCopy tiled_copy_s2t_SFB; + SmemFrgSFB thr_tCsSFB_s2t; + TmemFrgSFB thr_tCtSFB_s2t; + + CUTLASS_DEVICE + MmaParams ( + TiledMma tiled_mma_, + FragmentA tCrA_, FragmentB tCrB_, + FragmentE tCtE_, ETiledCopy tiled_copy_s2t_E_, + SmemFrgE thr_tCsE_s2t_, TmemFrgE thr_tCtE_s2t_, + FragmentSFA tCtSFA_, SFATiledCopy tiled_copy_s2t_SFA_, + SmemFrgSFA thr_tCsSFA_s2t_, TmemFrgSFA thr_tCtSFA_s2t_, + FragmentSFB tCtSFB_, SFBTiledCopy tiled_copy_s2t_SFB_, + SmemFrgSFB thr_tCsSFB_s2t_, 
TmemFrgSFB thr_tCtSFB_s2t_) + : tiled_mma(tiled_mma_) + , tCrA(tCrA_), tCrB(tCrB_) + , tCtE(tCtE_), tiled_copy_s2t_E(tiled_copy_s2t_E_) + , thr_tCsE_s2t(thr_tCsE_s2t_), thr_tCtE_s2t(thr_tCtE_s2t_) + , tCtSFA(tCtSFA_), tiled_copy_s2t_SFA(tiled_copy_s2t_SFA_) + , thr_tCsSFA_s2t(thr_tCsSFA_s2t_), thr_tCtSFA_s2t(thr_tCtSFA_s2t_) + , tCtSFB(tCtSFB_), tiled_copy_s2t_SFB(tiled_copy_s2t_SFB_) + , thr_tCsSFB_s2t(thr_tCsSFB_s2t_), thr_tCtSFB_s2t(thr_tCtSFB_s2t_) {} + }; + + // Host side kernel arguments + struct Arguments { + // A is A Compressed, not raw tensorA + ArrayElementA const* ptr_A{nullptr}; + LayoutA layout_a{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementE const* ptr_E{nullptr}; + LayoutE layout_e{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), LayoutA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_E = decltype(make_tma_atom_A_sm100( // use uint64_t to get the largest loading box. 
+ GmemCopyAtomE{}, + make_tensor(recast_ptr(nullptr), LayoutE{}), + SmemLayoutE{}(_,_,_,cute::Int<0>{}), + TileShapeE{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_E tma_load_e; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_E tma_load_e_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + LayoutA layout_a; + LayoutE layout_e; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , layout_a_(params.layout_a) + , layout_e_(params.layout_e) + , layout_SFA_(params.layout_SFA) + , layout_SFB_(params.layout_SFB) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + 
observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_e_ = is_fallback_cluster ? ¶ms.tma_load_e_fallback : ¶ms.tma_load_e; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? ¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_e_ = ¶ms.tma_load_e; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + auto ptr_E = recast_ptr(args.ptr_E); + + Tensor tensor_a = make_tensor(ptr_A, args.layout_a); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + Tensor tensor_e = make_tensor(ptr_E, args.layout_e); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + Tensor tensor_sfa = make_tensor(args.ptr_SFA, 
args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + // Cluster layout for TMA construction of SFB + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_E tma_load_e = make_tma_atom_A_sm100( // use uint64_t to get the largest loading box. + GmemCopyAtomE{}, + tensor_e, + SmemLayoutE{}(_,_,_,cute::Int<0>{}), + TileShapeE{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_E tma_load_e_fallback = make_tma_atom_A_sm100( // use uint64_t to get the largest loading box. 
+ GmemCopyAtomE{}, + tensor_e, + SmemLayoutE{}(_,_,_,cute::Int<0>{}), + TileShapeE{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk); + + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMMA_SF{}, + cluster_layout_sfb_vmnk_fallback); + + return { + tma_load_a, + tma_load_e, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_e_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + args.layout_a, + args.layout_e, + args.layout_SFA, + args.layout_SFB, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + + // Check for Alignment Requirement + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr 
int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits_v; + + bool implementable = true; + // Check Alignment A + if constexpr (is_A_mn_major) { + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M, K/2, L), + cute::make_stride(_1{}, M, M*K/2)); + } + else { // If A is K-major + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M, K/2, L), + cute::make_stride(K/2, _1{}, M*K/2)); + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA on tensorA\n"); + } + + // Check Alignment B + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits_v; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA on tensorB\n"); + } + + // Check for AB layout requirement + const auto layout_a_ref = SparseConfig::fill_layoutA(problem_shape_MNKL); + const auto layout_e_ref = SparseConfig::fill_layoutE(problem_shape_MNKL); + implementable = implementable && (layout_a_ref == args.layout_a); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_a mismatch\n"); + } + + implementable = implementable && (layout_e_ref == args.layout_e); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_e mismatch\n"); + } + + // Check for SFA SFB layout requirement + const auto layout_sfa_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + const auto layout_sfb_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + implementable = implementable && (layout_sfa_ref == args.layout_SFA); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFA mismatch, layout_SFA needs to be K-major\n"); + } + + implementable = implementable && (layout_sfb_ref == 
args.layout_SFB); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFB mismatch, layout_SFB needs to be K-major\n"); + } + + if constexpr (IsRuntimeDataType && detail::is_sm10x_mxf4nvf4_input() && detail::is_sm10x_mxf4nvf4_input()) { + bool is_compatible = (SFVecSize == 32 || + (SFVecSize == 64 && is_same_v + && args.runtime_data_type_a == cute::UMMA::MXF4Format::E2M1 + && args.runtime_data_type_b == cute::UMMA::MXF4Format::E2M1)); + if (!is_compatible) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: 2x mode (VectorSize=64) only supports float_e2m1_t for a/b types and ue8m0_t for sf type.\n"); + } + implementable &= is_compatible; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_e_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfa_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfb_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // 
((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + Tensor tCtE = make_tensor(take<0,3>(shape(SmemLayoutE{}))); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + tmem_storage.tCtE = tCtE; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtE.data() = tmem_base_addr + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFA.data() = tmem_storage.tCtE.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtE); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); + } + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// tAgSFA_mkl - partitioned gmem tensor for SFA + /// tBgSFB_nkl - partitioned gmem tensor for SFB + /// tAsSFA - partitioned tmem tensor for SFA + /// tAsSFB - partitioned tmem tensor for SFB + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(layout_a_.shape()); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + Tensor mE_mkl = observed_tma_load_e_->get_tma_tensor(layout_e_.shape()); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + Tensor gE_mkl = local_tile(mE_mkl, TileShapeE{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA_)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape 
(make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else if constexpr (IsCtaN64) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + } + }(); + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + Tensor tCgE_mkl = cta_mma.partition_A(gE_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (MMA,MMA_M,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMMA_SF{}.get_slice(blockIdx.x % size(typename TiledMMA_SF::AtomThrID{})); + Tensor 
tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + // Project the cta_layout for tma_a along the n-modes + auto [tEgE_mkl, tEsE] = tma_partition(*observed_tma_load_e_, + get<2>(cta_coord_vmnk), 
make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sE), group_modes<0,3>(tCgE_mkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + uint16_t mcast_mask_e = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + + return LoadParams{ + size<3>(gA_mkl), // for scheduler + tAgA_mkl, tBgB_nkl, tEgE_mkl, tAsA, tBsB, tEsE, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_a, mcast_mask_b, mcast_mask_e, mcast_mask_sfa, mcast_mask_sfb}; // multicast masks + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A B E matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (MMA,MMA_M,MMA_K,PIPE) that one UTCCP can provide + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sE)); // PIPE + + Tensor tCtE = tmem_storage.tCtE; + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpEOp = 
cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + cute::SM100_UTCCP_128dp128bit_2cta, cute::SM100_UTCCP_128dp128bit_1cta>; + auto tiled_copy_s2t_E = make_utccp_copy(UtccpEOp{}, recast(tCtE)); + + auto thr_copy_s2t_E = tiled_copy_s2t_E.get_slice(0); + Tensor thr_tCsE_s2t_ = thr_copy_s2t_E.partition_S(recast(sE)); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + Tensor thr_tCsE_s2t = get_utccp_smem_desc_tensor(thr_tCsE_s2t_); + Tensor thr_tCtE_s2t = thr_copy_s2t_E.partition_D(recast(tCtE)); + + // + // Scale Factor + // + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. 
+ auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_s2t_); + auto thr_tCtSFA_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_s2t_); + auto thr_tCtSFB_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + + return MmaParams{ + tiled_mma, + tCrA, tCrB, + tCtE, tiled_copy_s2t_E, thr_tCsE_s2t, thr_tCtE_s2t, + tCtSFA, tiled_copy_s2t_SFA, thr_tCsSFA_s2t, thr_tCtSFA_s2t, + tCtSFB, tiled_copy_s2t_SFB, thr_tCsSFB_s2t, thr_tCtSFB_s2t}; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class LoadParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + LoadParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [k_tiles, + tAgA_mkl, tBgB_nkl, tEgE_mkl, tAsA, tBsB, tEsE, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, + mcast_mask_a, mcast_mask_b, mcast_mask_e, + mcast_mask_sfa, mcast_mask_sfb] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tEgE = tEgE_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + // Note: We don't synchronize the sf_pipeline for "Buffer_Empty". 
We use mainloop pipeline + // to do the synchronization at once. + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + copy(observed_tma_load_sfa_->with(*tma_barrier, mcast_mask_sfa), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier, mcast_mask_sfb), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + copy(observed_tma_load_e_->with(*tma_barrier, mcast_mask_e), tEgE(_,*k_tile_iter), tEsE(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class MmaParams, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> 
const& accumulators_pair, + MmaParams const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, + tCrA, tCrB, + tCtE, tiled_copy_s2t_E, thr_tCsE_s2t, thr_tCtE_s2t, + tCtSFA, tiled_copy_s2t_SFA, thr_tCsSFA_s2t, thr_tCtSFA_s2t, + tCtSFB, tiled_copy_s2t_SFB, thr_tCsSFB_s2t, thr_tCtSFB_s2t] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (size<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else if constexpr (IsCtaN64) { + // Move in increments of 64 columns of SFB + auto tCtSFB_tmp = tCtSFB; + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + (size<1>(cta_tile_coord) % 2) * 2; + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + if constexpr (IsOverlappingAccum) { + // first iteration manual unroll for tmem overlap kernel + if (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current 
mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_E, thr_tCsE_s2t(_,_,_,_,read_stage), thr_tCtE_s2t); + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtE(_,_,k_block), + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + } + else { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + 
--k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_E, thr_tCsE_s2t(_,_,_,_,read_stage), thr_tCtE_s2t); + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtE(_,_,k_block), + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_E const* observed_tma_load_e_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + LayoutA layout_a_; + LayoutE layout_e_; + LayoutSFA layout_SFA_; + LayoutSFB layout_SFB_; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d832a1fc4f3ae135ed32d10b266b34381cecee47 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp @@ -0,0 +1,894 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class 
StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool 
IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem 
sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + using TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + static constexpr bool IsGroupedGemmKernel = 
!cute::is_same_v; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const** ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const** ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + 
GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const** ptr_A; + StrideA dA; + ArrayElementB const** ptr_B; + StrideB dB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. 
+ auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + auto init_L = get<3>(init_shape); + + // Tensor pointers will be fixed before the first access + TmaInternalElementA const* ptr_A_first_batch = nullptr; + TmaInternalElementB const* ptr_B_first_batch = nullptr; + + InternalStrideA stride_a; + InternalStrideB stride_b; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + + // Batches/Groups are managed by using appropriate pointers to input matrices. 
+ Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t 
SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + return partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + } + + template + CUTLASS_DEVICE static + auto + 
slice_accumulator(TmemStorage tmem_storage, int stage) { + return tmem_storage.accumulators(_,_,_,stage); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + [[maybe_unused]] int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L)); + + // 
Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-Cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // 
for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + [[maybe_unused]] TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = 
mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, 
skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor 
pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + TmaInternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + 
TmaInternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + 
if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp new file mode 100644 index 0000000000000000000000000000000000000000..812553afc959e972df280e767ab1de1b558634fc --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp @@ -0,0 +1,1342 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StridePairA_, + class ElementB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecializedBlockwiseScaling< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StridePairA_, + ElementB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + 
SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecializedBlockwiseScaling< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using LayoutSFA = cute::remove_cvref_t(StridePairA_{}))>; + using InternalStrideA = cute::remove_pointer_t; + using InternalLayoutSFA = cute::remove_pointer_t; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = cute::remove_cvref_t(StridePairB_{}))>; + using LayoutSFB = cute::remove_cvref_t(StridePairB_{}))>; + using InternalStrideB = cute::remove_pointer_t; + using InternalLayoutSFB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both 
static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + static constexpr int ScaleGranularityM = size<0,0>(InternalLayoutSFA{}); + + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static_assert(size<0>(TileShape{}) % ScaleGranularityM == 0 and ScaleGranularityM <= size<0>(TileShape{}), "Scale Granularity M must divide Tile Shape"); + + static constexpr int ScaleGranularityN = size<0,0>(InternalLayoutSFB{}); + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + static_assert(size<1>(TileShape{}) % ScaleGranularityN == 0 and ScaleGranularityN <= size<1>(TileShape{}), "Scale Granularity N must divide Tile Shape"); + + static_assert(size<1, 0>(InternalLayoutSFA{}) == size<1, 0>(InternalLayoutSFB{}), "Vector size K must be equal for SFA and SFB"); + + static constexpr int ScaleGranularityK = size<1, 0>(InternalLayoutSFA{}); + static constexpr int ScaleKsPerTile = size<2>(TileShape{}) / ScaleGranularityK; + static_assert(size<2>(TileShape{}) % ScaleGranularityK == 0 and ScaleGranularityK <= size<2>(TileShape{}), "Scale Granularity K must divide Tile Shape"); + static_assert(ScaleGranularityK % size<2>(typename TiledMma::AtomShape_MNK{}) == 0, "Scale Granularity K must be divisible by MMA_K"); + + static constexpr int K_BLOCK_MMAS_PER_SCALE_K = ScaleGranularityK / size<2>(typename TiledMma::AtomShape_MNK{}); + + static_assert(size<0>(CtaShape_MNK{}) >= ScaleGranularityM, "Scale Granularity must be smaller than or equal to the tile shape"); + static_assert(size<1>(CtaShape_MNK{}) >= ScaleGranularityN, "Scale Granularity must be smaller than or equal to the tile shape"); + static_assert(size<2>(CtaShape_MNK{}) >= ScaleGranularityK, "Scale Granularity must be smaller than or equal to the tile shape"); + + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig(InternalLayoutSFA{}.stride()) == 1 ? 
UMMA::Major::MN : UMMA::Major::K, + size<0,1>(InternalLayoutSFB{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K>; + + + using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(CtaShape_MNK{})); + using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(CtaShape_MNK{})); + + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = cute::remove_cvref_t(GmemTiledCopyPairA_{}))>; + using GmemTiledCopySFA = cute::remove_cvref_t(GmemTiledCopyPairA_{}))>; + using GmemTiledCopyB = cute::remove_cvref_t(GmemTiledCopyPairB_{}))>; + using GmemTiledCopySFB = cute::remove_cvref_t(GmemTiledCopyPairB_{}))>; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static constexpr int CopyAlignmentSFA = GmemTiledCopySFA::AtomNumVal::value * sizeof(typename GmemTiledCopySFA::ValType) / sizeof(ElementAccumulator); + static constexpr int CopyAlignmentSFB = GmemTiledCopySFB::AtomNumVal::value * sizeof(typename GmemTiledCopySFB::ValType) / sizeof(ElementAccumulator); + + static constexpr int AlignmentSFA = CopyAlignmentSFA * (GmemTiledCopySFA::AtomNumVal::value > 1 ? + (size<0,1>(InternalLayoutSFA{}.stride()) == 1 ? ScaleGranularityM : ScaleGranularityK) : 1); + static constexpr int AlignmentSFB = CopyAlignmentSFB * (GmemTiledCopySFB::AtomNumVal::value > 1 ? + (size<0,1>(InternalLayoutSFB{}.stride()) == 1 ? 
ScaleGranularityN : ScaleGranularityK) : 1); + + + using MainloopABPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopABPipelineState = typename MainloopABPipeline::PipelineState; + + using MainloopSFPipeline = cutlass::PipelineAsync; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync< + AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + // Two arrivals per thread in the warp (1 arrival and 1 arrival through cp.async.mbarrier) + static constexpr int NumMainloopSFProducerThreadEvents = 64; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + using TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = uint_bit_t>; + using BitTypeElementB = uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + using SmemLayoutScaleA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutScaleB = decltype(make_layout( + 
append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineABStorage = typename MainloopABPipeline::SharedStorage; + using PipelineSFStorage = typename MainloopSFPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + struct PipelineStorage { + alignas(16) PipelineABStorage pipeline_ab; + alignas(16) PipelineSFStorage pipeline_sf; + alignas(16) AccumulatorPipelineStorage pipeline_accum; + }; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
+ using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const** ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const** ptr_B{nullptr}; + StrideB dB{}; + ElementAccumulator const** ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementAccumulator const** ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + 
RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const** ptr_A; + StrideA dA; + ArrayElementB const** ptr_B; + StrideB dB; + + ElementAccumulator const** ptr_SFA; + LayoutSFA layout_SFA; + ElementAccumulator const** ptr_SFB; + LayoutSFB layout_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. 
+ auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + auto init_L = get<3>(init_shape); + + // Tensor pointers will be fixed before the first access + TmaInternalElementA const* ptr_A_first_batch = nullptr; + TmaInternalElementB const* ptr_B_first_batch = nullptr; + + InternalStrideA stride_a; + InternalStrideB stride_b; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + + // Batches/Groups are managed by using appropriate pointers to input matrices. 
+ Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + args.ptr_SFA, + args.layout_SFA, + args.ptr_SFB, + args.layout_SFB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int 
sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + bool implementable_sf = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + implementable_sf = implementable_sf && cutlass::detail::check_alignment(ScaleConfig::tile_atom_to_shape_SFA(problem_shape_MNKL)); + implementable_sf = implementable_sf && cutlass::detail::check_alignment(ScaleConfig::tile_atom_to_shape_SFB(problem_shape_MNKL)); + if (!implementable_sf) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet 
the minimum alignment requirements for Scale Factors.\n"); + } + } + } + else { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n"); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + implementable = implementable && implementable_sf; + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE auto + slice_accumulator(cute::Tensor const& accumulators, int stage) { + return accumulators(_,_,_,stage); + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_ab_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + [[maybe_unused]] int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L)); + Tensor mB_nkl = 
observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-Cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + 
gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + template + CUTLASS_DEVICE auto + load_sf_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + int current_group) const { + return load_sf_update(problem_shape_MNKL, params, shared_tensors, current_group); + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + template + CUTLASS_DEVICE auto + load_sf_update( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + int current_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + const int32_t mock_L = 1; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L)); + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + auto layout_SFA = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsGroupedGemmKernel) { + return params.layout_SFA[current_group]; + } + else { + return params.layout_SFA; + } + }(); + + auto layout_SFB = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsGroupedGemmKernel) { + return params.layout_SFB[current_group]; + } + else { + return params.layout_SFB; + } + }(); + + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(params.ptr_SFA[current_group]), layout_SFA); // (m,k,l) + + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(params.ptr_SFB[current_group]), layout_SFB); // (n,k,l) + + 
Tensor SFA_mkl_ident = make_identity_tensor(shape(layout_SFA)); + + Tensor SFB_nkl_ident = make_identity_tensor(shape(layout_SFB)); + + // Tile the tensors and defer the slice + Tensor gSFA_mkl = local_tile(mSFA_mkl, CtaShape_MNK{}, + make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, CtaShape_MNK{}, + make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + Tensor identSFA_mkl = local_tile(SFA_mkl_ident, CtaShape_MNK{}, + make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor identSFB_nkl = local_tile(SFB_nkl_ident, CtaShape_MNK{}, + make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + static_assert(rank(decltype(gSFA_mkl){}) == 5); + static_assert(rank(decltype(gSFB_nkl){}) == 5); + + // 1 thread copies entire set of scalar + GmemTiledCopySFA scale_copy_a{}; + GmemTiledCopySFB scale_copy_b{}; + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x % size(scale_copy_a)); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x % size(scale_copy_b)); + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), + SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), + SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) + + Tensor tSFAgSFA_mkl = thr_scale_copy_a.partition_S(gSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) + Tensor tSFAIdentSFA_mkl = thr_scale_copy_a.partition_S(identSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) + + Tensor tSFAsSFA = thr_scale_copy_a.partition_D(sSFA); + + Tensor tSFBgSFB_nkl = thr_scale_copy_b.partition_S(gSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) + Tensor tSFBIdentSFB_nkl = thr_scale_copy_b.partition_S(identSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) + Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB); + + static_assert(rank(decltype(tSFAgSFA_mkl){}) == 6); + static_assert(rank(decltype(tSFBgSFB_nkl){}) == 6); + + return cute::make_tuple(gA_mkl, + 
tSFAgSFA_mkl, tSFBgSFB_nkl, + tSFAsSFA, tSFBsSFB, + tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, + layout_SFA, layout_SFB); + } + + /// Setup data needed for transform + CUTLASS_DEVICE auto + accum_init( + TensorStorage& shared_tensors) const { + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), + SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), + SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) + + return cute::make_tuple(sSFA, sSFB); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] cute::Tensor const& accumulators, + TensorStorage& shared_tensors, + [[maybe_unused]] uint32_t const tmem_nonaccum_offset) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA_ = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB_ = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(rank(tCrA_) == _4{}); + + auto mma_tile_shape_A = make_shape(get<0>(shape(tCrA_.layout())), + get<1>(shape(tCrA_.layout())), + Int{}, + _1{}); + + auto mma_tile_shape_B = make_shape(get<0>(shape(tCrB_.layout())), + get<1>(shape(tCrB_.layout())), + Int{}, + _1{}); + + Tensor tCrA = flat_divide(tCrA_, + mma_tile_shape_A)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_M,MMA_K_PER_SCALE,MMA_K_REST,PIPE) + + Tensor tCrB = flat_divide(tCrB_, + mma_tile_shape_B)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_N,MMA_K_PER_SCALE,MMA_K_REST,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to 
runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_ab( + Params const& params, + MainloopABPipeline mainloop_ab_pipeline, + MainloopABPipelineState mainloop_ab_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(input_tensormaps); + } + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_ab_pipeline.producer_try_acquire(mainloop_ab_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_ab_pipeline.producer_acquire(mainloop_ab_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopABPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_ab_pipeline.producer_get_barrier(mainloop_ab_pipe_producer_state); + + int 
write_stage = mainloop_ab_pipe_producer_state.index(); + ++mainloop_ab_pipe_producer_state; + barrier_token = mainloop_ab_pipeline.producer_try_acquire(mainloop_ab_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_ab_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_ab_tail(MainloopABPipeline mainloop_ab_pipeline, MainloopABPipelineState mainloop_ab_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_ab_pipeline.producer_tail(mainloop_ab_pipe_producer_state); + } + + /// Perform a collective-scoped transform + /// Producer Perspective + template < + class UnusedGTensorA, + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class IdentPartitionedSFA, class IdentPartitionedSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_sf( + MainloopSFPipeline mainloop_sf_pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state, + cute::tuple const& mainloop_sf_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused, tSFAgSFA_mkl, tSFBgSFB_nkl, + tSFAsSFA, tSFBsSFB, + tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, + layout_SFA, layout_SFB] = mainloop_sf_inputs; + + // slice out the work coord from partitioned tensors + GmemTiledCopySFA 
scale_copy_a{}; + GmemTiledCopySFB scale_copy_b{}; + + Tensor tSFAgSFA = tSFAgSFA_mkl(_, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + Tensor tSFBgSFB = tSFBgSFB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + Tensor thr_tile_SFA_k = tSFAIdentSFA_mkl(_0{}, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor thr_tile_pSFA = make_tensor(shape(filter_zeros(thr_tile_SFA_k(_,_,_0{}), tSFAgSFA(_0{},_,_,_0{}).stride()))); + Tensor thr_tile_SFB_k = tSFBIdentSFB_nkl(_0{}, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + Tensor thr_tile_pSFB = make_tensor(shape(filter_zeros(thr_tile_SFB_k(_,_,_0{}), tSFBgSFB(_0{},_,_,_0{}).stride()))); + + // Issue the loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK pipe_producer_state for _writing_ + mainloop_sf_pipeline.producer_acquire(mainloop_sf_pipe_producer_state); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(thr_tile_pSFA); ++i) { + Tensor thr_tile_SFA = filter_zeros(thr_tile_SFA_k(_,_,*k_tile_iter), tSFAgSFA(_0{},_,_,_0{}).stride()); + thr_tile_pSFA(i) = elem_less(thr_tile_SFA(i), shape(filter_zeros(layout_SFA))) && threadIdx.x % 32 < size(scale_copy_a); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(thr_tile_pSFB); ++i) { + Tensor thr_tile_SFB = filter_zeros(thr_tile_SFB_k(_,_,*k_tile_iter), tSFBgSFB(_0{},_,_,_0{}).stride()); + thr_tile_pSFB(i) = elem_less(thr_tile_SFB(i), shape(filter_zeros(layout_SFB))) && threadIdx.x % 32 < size(scale_copy_b); + } + + copy_if(scale_copy_a, thr_tile_pSFA, filter_zeros(tSFAgSFA(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,mainloop_sf_pipe_producer_state.index()))); + copy_if(scale_copy_b, thr_tile_pSFB, filter_zeros(tSFBgSFB(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,mainloop_sf_pipe_producer_state.index()))); + mainloop_sf_pipeline.producer_commit(mainloop_sf_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); + + __syncwarp(); + + 
++mainloop_sf_pipe_producer_state; + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_sf_pipe_producer_state, k_tile_iter); + + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_sf_tail( + MainloopSFPipeline mainloop_sf_pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_sf_pipeline.producer_tail(mainloop_sf_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 4, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N, P)"); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + 
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + CUTLASS_PRAGMA_UNROLL + for (int scale_k_iter = 0; scale_k_iter < size<3>(tCrA); ++scale_k_iter) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + auto acc = slice_accumulator(accumulators, accumulator_pipe_producer_state.index()); + static_assert(is_tmem>::value, "Accumulator must be tmem resident."); + static_assert(rank(remove_cvref_t{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + // for each set of scale_k_iter we zero the accumulator + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, + tCrA(_,_,k_block,scale_k_iter,read_stage), + tCrB(_,_,k_block,scale_k_iter,read_stage), + acc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + + } + + return make_tuple(mainloop_pipe_consumer_state, accumulator_pipe_producer_state); + + } + + /// Transform + template < + class FrgEngine, + class FrgLayout, + class TensorsSFA, + class TensorsSFB, + class CtaTileCoord, + class CopyOpT2R, + class EpilogueTile + > + CUTLASS_DEVICE auto + accum( + cute::tuple pipelines, + cute::tuple consumer_states, + cute::Tensor const& accumulators, + cute::tuple 
const& transform_inputs, + CtaTileCoord cta_tile_coord, + CopyOpT2R, + EpilogueTile, + int k_tile_count) { + + static_assert(size<0>(EpilogueTile{}) <= size<0>(CtaShape_MNK{}), "Restrict epilogue tile to be smaller than or equal to CTA Tile"); + static_assert(size<1>(EpilogueTile{}) <= size<1>(CtaShape_MNK{}), "Restrict epilogue tile to be smaller than or equal to CTA Tile"); + + + // + // PIPELINED Transform + // + + Tensor acc = slice_accumulator(accumulators, _0{}); + Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + auto [sSFA_, sSFB_] = transform_inputs; + + // Append N with a stride of 0 to SFA + Tensor sSFA = make_tensor(sSFA_.data(), make_layout( + make_shape(get<0>(sSFA_.shape()), get<1>(CtaShape_MNK{}), get<1>(sSFA_.shape()), get<2>(sSFA_.shape())), + make_stride(get<0>(sSFA_.stride()), _0{}, get<1>(sSFA_.stride()), get<2>(sSFA_.stride())) + )); + + CUTE_STATIC_ASSERT_V(size<0>(sSFA) == size<0>(tAcc)); + CUTE_STATIC_ASSERT_V(size<1>(sSFA) == size<1>(tAcc)); + + Tensor sSFA_epi = flat_divide(sSFA, EpilogueTile{}); + + // Append M with a stride of 0 to SFB + Tensor sSFB = make_tensor(sSFB_.data(), make_layout( + make_shape(get<0>(CtaShape_MNK{}), get<0>(sSFB_.shape()), get<1>(sSFB_.shape()), get<2>(sSFB_.shape())), + make_stride(_0{}, get<0>(sSFB_.stride()), get<1>(sSFB_.stride()), get<2>(sSFB_.stride())) + )); + + CUTE_STATIC_ASSERT_V(size<0>(sSFB) == size<0>(tAcc)); + CUTE_STATIC_ASSERT_V(size<1>(sSFB) == size<1>(tAcc)); + + Tensor sSFB_epi = flat_divide(sSFB, EpilogueTile{}); + + TiledCopy tiled_t2r_epi = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + + int thread_idx = threadIdx.x % size(tiled_t2r_epi); + + ThrCopy thread_t2r_epi = tiled_t2r_epi.get_slice(thread_idx); + + Tensor acc_ident_epi = make_identity_tensor(shape(tAcc_epi)); + + Tensor tTR_rAcc_epi = thread_t2r_epi.partition_D(acc_ident_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + Tensor 
tTR_sSFA_epi = thread_t2r_epi.partition_D(sSFA_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + Tensor tTR_sSFB_epi = thread_t2r_epi.partition_D(sSFB_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + static_assert(rank(decltype(tTR_sSFA_epi){}) == 7); + + Tensor tTR_FullAcc = make_tensor(shape(tTR_rAcc_epi)); + Tensor tTR_PartAcc = make_tensor(shape(tTR_rAcc_epi(_,_,_,_0{},_0{}))); + + Tensor tTR_rSFA_compact = make_fragment_like(filter_zeros(tTR_sSFA_epi(_,_,_,_,_,_,_0{}))); + Tensor tTR_rSFB_compact = make_fragment_like(filter_zeros(tTR_sSFB_epi(_,_,_,_,_,_,_0{}))); + + Layout tTR_rSFA_layout = make_layout(tTR_sSFA_epi(_,_,_,_,_,_,_0{}).shape(), tTR_rSFA_compact.stride()); + Layout tTR_rSFB_layout = make_layout(tTR_sSFB_epi(_,_,_,_,_,_,_0{}).shape(), tTR_rSFB_compact.stride()); + + // Zero our accumulator + clear(tTR_FullAcc); + + auto [accumulator_pipeline, mainloop_sf_pipeline] = pipelines; + auto [accumulator_pipe_state, mainloop_sf_pipe_state] = consumer_states; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_sf_pipeline.consumer_wait(mainloop_sf_pipe_state); + int read_idx = mainloop_sf_pipe_state.index(); + + copy(filter_zeros(tTR_sSFA_epi(_,_,_,_,_,_,read_idx)), tTR_rSFA_compact); + copy(filter_zeros(tTR_sSFB_epi(_,_,_,_,_,_,read_idx)), tTR_rSFB_compact); + + CUTE_STATIC_ASSERT_V(cosize(tTR_rSFA_layout) == size(tTR_rSFA_compact)); + CUTE_STATIC_ASSERT_V(cosize(tTR_rSFB_layout) == size(tTR_rSFB_compact)); + + Tensor tTR_rSFA = make_tensor(tTR_rSFA_compact.data(), tTR_rSFA_layout); + Tensor tTR_rSFB = make_tensor(tTR_rSFB_compact.data(), tTR_rSFB_layout); + + mainloop_sf_pipeline.consumer_release(mainloop_sf_pipe_state); + ++mainloop_sf_pipe_state; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < ScaleKsPerTile; ++k_block) { + + accumulator_pipeline.consumer_wait(accumulator_pipe_state); + + Tensor acc = slice_accumulator(accumulators, accumulator_pipe_state.index()); + Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); + Tensor 
tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + Tensor tTR_tAcc = thread_t2r_epi.partition_S(tAcc_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(tAcc_epi); ++epi_m) { + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(tAcc_epi); ++epi_n) { + + auto scale_a = tTR_rSFA(_,_,_,epi_m,epi_n,k_block * ScaleGranularityK); + auto scale_b = tTR_rSFB(_,_,_,epi_m,epi_n,k_block * ScaleGranularityK); + + Tensor full_acc = tTR_FullAcc(_,_,_,epi_m,epi_n); + // Compute tmem load predication if necessary + copy(tiled_t2r_epi, tTR_tAcc(_,_,_,epi_m,epi_n), tTR_PartAcc); + cutlass::arch::fence_view_async_tmem_load(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(full_acc); ++i) { + ElementAccumulator scale = scale_a(i) * scale_b(i); + full_acc(i) += scale * tTR_PartAcc(i); + } + } + } + cutlass::arch::fence_view_async_tmem_load(); + accumulator_pipeline.consumer_release(accumulator_pipe_state); + // release acc + ++accumulator_pipe_state; + } + + --k_tile_count; + } + + return cute::make_tuple(tTR_FullAcc, tiled_t2r_epi, cute::make_tuple(accumulator_pipe_state, mainloop_sf_pipe_state)); + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap 
= make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + TmaInternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + TmaInternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + 
prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + +private: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a90566d721f6d18cdca2f3575687991685442b0 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp @@ -0,0 +1,1126 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" +#include "cutlass/detail/cluster.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for FastF32 Kernels +template < + int Load2TransformPipelineStageCount_, + int Transform2MmaPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + int NumBandsToCompute_, + int ScalingFactor_, + int AccPromotionInterval_, + class AccumulatorCopyAtom_, + class ClusterShape, + class TileShape_, + class StrideA_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>, + TileShape_, + float, + StrideA_, + float, + 
StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ + // + // Type Aliases + // + + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = float; + using PackedElementA = float2; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementAMma = typename TiledMma::ValTypeA; + using PackedElementAMma = uint32_t; + using ElementB = float; + using PackedElementB = float2; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using ElementBMma = typename TiledMma::ValTypeB; + using PackedElementBMma = uint32_t; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename 
DispatchPolicy::ArchTag; + + static_assert(cute::is_same_v, "Input type A should be float"); + static_assert(cute::is_same_v, "Input type B should be float"); + static_assert(cute::is_same_v, "Compute type A should be cutlass::bfloat16_t"); + static_assert(cute::is_same_v, "Compute type A should be cutlass::bfloat16_t"); + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + DispatchPolicy::Transform2MmaPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + // Thread Counts + static constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; + + // Get the Algorithm parameters + constexpr static int NumComputeMtxs = 3; + constexpr static int NumBandsToCompute = DispatchPolicy::NumBandsToCompute; + constexpr static int ScalingFactor = DispatchPolicy::ScalingFactor; + constexpr static int AccPromotionInterval = DispatchPolicy::AccPromotionInterval; + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}) / DispatchPolicy::AccPromotionInterval; + constexpr static int NumBandsMax = 5; + static_assert(NumBandsToCompute <= NumBandsMax && NumBandsToCompute >= 3, "NumBandsToCompute should be less than maximum number of bands"); + + // Copy atom for Accumulator + using AccumulatorCopyAtom = typename DispatchPolicy::AccumulatorCopyAtom; + + static_assert((NumBandsToCompute == 5 || 
NumBandsToCompute == 4 || NumBandsToCompute == 3), + "9xBF16 with 5/4/3 Bands are supported"); + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + append(append(CtaShapeA_MK{}, Int{}), Int{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutBCompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomBCompute{}, + append(append(CtaShapeB_NK{}, Int{}), Int{}))); + + static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2 && DispatchPolicy::Load2TransformPipelineStageCount >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert((cute::is_base_of::value || + cute::is_base_of::value ) && + cute::is_base_of::value, + "MMA atom must A operand from SMEM or TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyB - invalid TMA copy atom specified."); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + struct TensorStorageUntransformed { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + }; + + struct TensorStorageTransformedAinSmem { + alignas(1024) cute::ArrayEngine> 
smem_ACompute; + alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + union TensorStorageTransformedAinTmem { + alignas(1024) cute::ArrayEngine smem_ACompute; // No smem_ACompute + alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + using TensorStorageTransformed = cute::conditional_t< + cute::is_base_of::value, + TensorStorageTransformedAinSmem, + TensorStorageTransformedAinTmem>; + + TensorStorageUntransformed input; + TensorStorageTransformed compute; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A{nullptr}; + StrideA dA{}; + ElementB const** ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_B = decltype(make_tma_atom_B_sm100( + 
GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + cute::TmaDescriptor* tensormaps; + ElementA const** ptr_A; + ElementB const** ptr_B; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments(ProblemShape problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Tensor shapes for Ptr-Array are initialized correctly here. 
+ auto [M,N,K,mock_L] = problem_shape.get_host_problem_shape(0); + // Batches/Groups are managed by using appropriate pointers to input matrices + mock_L = 1; + + // Tensor pointers will be fixed before the first access + ElementA const* ptr_A_first_batch = nullptr; + ElementB const* ptr_B_first_batch = nullptr; + + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + reinterpret_cast(workspace), 
+ reinterpret_cast(args.ptr_A), + reinterpret_cast(args.ptr_B) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto [M,N,K,L] = problem_shape.get_host_problem_shape(0); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, 
class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + Load2TransformPipeline pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK mainloop_load2xform_pipeline_state for _writing_ + pipeline.producer_acquire(load2xform_pipeline_state, pipeline_flag); + int write_stage = load2xform_pipeline_state.index(); + + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop_pipe + ++load2xform_pipeline_state; + skip_wait = (k_tile_count <= 1); + pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + ++k_tile_iter; + } + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + } + + /// Set up the data needed by this collective for load. 
+ /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_storage, + int32_t const sm_count, int32_t const sm_idx) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = 
create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class DstCopyA, class SrcTensorA, class DstTensorA, + class GTensorB, class SrcTensorB, class DstTensorB + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + + static_assert(cute::is_same_v, "ElementA and ElementB types should be the same."); + static_assert(cute::is_same_v, "ElementAMma and ElementBMma types should be the same."); + + cutlass::arch::NamedBarrier transform_bar(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAdA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, NumComputeMtxs, SmemStages (In SMEM or TMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, NumComputeMtxs, SmemStages (In SMEM) + auto [unused_tAgA, dst_copy_A, tAsA, tAdACompute, + unused_tBgB, tBsB, tBsBCompute] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArA_temp = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArACompute = make_tensor(tAsA(_,_,_,_,0).shape()); + + auto tBrB = 
make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrB_temp = make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrBCompute = make_tensor(tBsB(_,_,_,_,0).shape()); + + auto tArA_x2 = recast>(tArA); + auto tArA_temp_x2 = recast>(tArA_temp); + auto tArACompute_x2 = recast>(tArACompute); + + auto tBrB_x2 = recast>(tBrB); + auto tBrB_temp_x2 = recast>(tBrB_temp); + auto tBrBCompute_x2 = recast>(tBrBCompute); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Copy the input B matrix from SMEM + copy(AutoVectorizingCopy{}, tBsB(_,_,_,_,load2transform_consumer_index), tBrB); + // Copy the input A matrix from SMEM + copy(AutoVectorizingCopy{}, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tBrB_x2, tBrBCompute_x2, cutlass::NumericArrayConverter::convert); + copy(AutoVectorizingCopy{}, tBrBCompute, tBsBCompute(_,_,_,_,comp_mtx_index,transform2mma_producer_index)); + + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < 
NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tBrBCompute_x2, tBrB_temp_x2, cutlass::NumericArrayConverter::convert); + cute::transform(tBrB_x2, tBrB_temp_x2, tBrB_x2, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tBrB_x2, tBrB_x2, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + } + + // Loads from SMEM are done. Signal the mainloop load as early as possible + transform_bar.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tArA_x2, tArACompute_x2, cutlass::NumericArrayConverter::convert); + copy(dst_copy_A, tArACompute, tAdACompute(_,_,_,_,comp_mtx_index,transform2mma_producer_index)); + + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tArACompute_x2, tArA_temp_x2, cutlass::NumericArrayConverter::convert); + cute::transform(tArA_x2, tArA_temp_x2, tArA_x2, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tArA_x2, tArA_x2, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + } + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag = 
load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sB_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor sB = as_position_independent_swizzle_tensor(sB_orig); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] ( + auto tensor_input, + auto input_copy_atom, + auto tensor_compute, + auto make_fragment, + auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. 
+ Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); // ((128,16),m,k,PIPE) + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto reg2tmem_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0,0)); + auto thr_reg2tmem_tiled_copy = reg2tmem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2tmem_tiled_copy.partition_S(tensor_input2x); + auto partitioned_tensor_compute = thr_reg2tmem_tiled_copy.partition_D(fragment_compute); + return cute::make_tuple(reg2tmem_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto reg2smem_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _8,_1>>{}, + tensor_compute(_,_,_,0,0).layout()); + + auto thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(tensor_input); + auto partitioned_tensor_compute = thr_reg2smem_tiled_copy.partition_D(tensor_compute_ind_sw); + + return cute::make_tuple(AutoVectorizingCopy{}, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [dst_copy_A, tAsA, tAsACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + auto [dst_copy_B, tBsB, tBsBCompute] = + setup_copy_ops(sB, InputCopyAtomB{}, sBCompute, [&](auto &arg) {return TiledMma::make_fragment_B(arg);}, ComputeCopyAtomB{}); + + return cute::make_tuple(gA_mkl, dst_copy_A, tAsA, 
tAsACompute, + gB_nkl, tBsB, tBsBCompute); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + + // tCrA : (MMA), MMA_M, MMA_K, NumComputeMtxs, SmemStage (In SMEM or TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, NumComputeMtxs, SmemStages (In SMEM) + auto const [tCrA, tCrB] = input_operands; + + using ZeroScaler = cute::integral_constant; + using Scaler = cute::integral_constant; + + int remaining_accum_promotions = k_tile_count * StagesPerTile; + uint32_t mma2accum_skip_wait = (remaining_accum_promotions <= 0); + auto mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); k_block += DispatchPolicy::AccPromotionInterval, 
--remaining_accum_promotions) { + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state, mma2accum_flag); + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC = accumulators(_,_,_,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + + ++mma2accum_pipeline_producer_state; + mma2accum_skip_wait = (remaining_accum_promotions <= 1); + mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + auto tCrA0 = tCrA(_,_,_,0,transform2mma_pipeline_consumer_state_index); + auto tCrA1 = tCrA(_,_,_,1,transform2mma_pipeline_consumer_state_index); + auto tCrA2 = tCrA(_,_,_,2,transform2mma_pipeline_consumer_state_index); + + auto tCrB0 = tCrB(_,_,_,0,transform2mma_pipeline_consumer_state_index); + auto tCrB1 = tCrB(_,_,_,1,transform2mma_pipeline_consumer_state_index); + auto tCrB2 = tCrB(_,_,_,2,transform2mma_pipeline_consumer_state_index); + + // MMA instructions Emulation + auto accumulate = UMMA::ScaleOut::Zero; + + // First set of GEMMs that we need to perform for each band are unrolled to set compile-time constant + // scaling parameter. Scaled GEMM operations are only needed for the first MMA operation of each band. 
+ + // Band 5 + if constexpr (NumBandsToCompute == 5) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[2] + accumulate = UMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[2]*B[2] + } + } + // Band 4 + if constexpr (NumBandsToCompute >= 4) { + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA1(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[1]*B[2] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[2]*B[1] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[1]*B[2] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[2]*B[1] + } + } + // Band 3 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[0] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[2] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[2]*B[0] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[2] + } + // Band 2 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), 
tCrA0(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[1]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[1]*B[0] + } + // Band 1 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[0] + } + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + } + + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + return cute::make_tuple(curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + auto get_tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + 
cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + }; + + Tensor tCrA = get_tCrA(); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + Tensor tCrB = tiled_mma.make_fragment_B(sBCompute); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { + // Obtain a single accumulator + Tensor tAcc = tensor<0>(accumulators(_,_,_,_0{})); + // Apply epilogue subtiling + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + // Create the TMEM copy for single EpilogueTile. + // Note that EpilogueTile = CtaTile for NoSmem epilogue + auto tiled_t2r = make_tmem_copy(tmem_cp_atom, tAcc_epi(_,_,_0{},_0{})); + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(tAcc_epi); + Tensor tTR_rAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rGlobAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) + Tensor tTR_rGlobAcc_float2 = recast>(tTR_rGlobAcc); // (T2R/2,T2R_M,T2R_N) + + // Apply epilogue subtiling to bulk accumulator + // We need to tile the whole bulk_tmem allocation with EpilogueTile. 
+ // The accumulation should be aware of the AccumulatorPipelineStages + Tensor tBulkAcc_epi = flat_divide(accumulators(make_coord(_,_),_0{},_0{}, _), EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,PIPE) + Tensor tTR_tBulkAcc = thread_t2r.partition_S(tBulkAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N,PIPE) + return cute::make_tuple(tiled_t2r, thread_t2r, tTR_tBulkAcc, tTR_rAcc, tTR_rGlobAcc); + } + + template + CUTLASS_DEVICE auto + accum(cute::tuple accum_inputs, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_consumer_state, + int k_tile_count) { + auto [tiled_t2r, thread_t2r, tTR_tBulkAcc, + tTR_rAcc, tTR_rGlobAcc] = accum_inputs; + + + Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) + Tensor tTR_rGlobAcc_float2 = recast>(tTR_rGlobAcc); // (T2R/2,T2R_M,T2R_N) + + // Clear the global accumulator + CUTE_UNROLL + for (int i = 0; i 0; --k_tile_count) { + // The stage is limited to a CTA tile + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block>{}); + + cutlass::arch::fence_view_async_tmem_load(); // Need a fence bw TMEM_LOAD and arrive + mma2accum_pipeline.consumer_release(mma2accum_pipeline_consumer_state); + + ++mma2accum_pipeline_consumer_state; + skip_wait = ((k_tile_count <= 1) && (k_block >= (StagesPerTile-1))); + mma2accum_flag = mma2accum_pipeline.consumer_try_wait(mma2accum_pipeline_consumer_state, skip_wait); + } + } + return cute::make_tuple(mma2accum_pipeline_consumer_state, tTR_rGlobAcc); + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to gmem for 
modification later + Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(gA_tensormap)); + copy(recast(pB_tensormap), recast(gB_tensormap)); + } + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Bringing tensormaps to smem (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_fetch_to_smem( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps) const { + Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(gA_tensormap), recast(sA_tensormap)); + copy(recast(gB_tensormap), recast(sB_tensormap)); + + cp_async_fence(); + cp_async_wait<0>(); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + cute::tuple const& 
input_tensormaps, + int32_t next_batch, + uint32_t lane_predicate) { + if (lane_predicate) { + // Bringing tensormaps to smem + tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps); + + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + +protected: + + template + CUTLASS_DEVICE + constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B 
const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e744ffb6c2eec59e29f2e2f2fe123a60e3df6b4e --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster 
shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + using DispatchPolicy = MainloopSm100UmmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster shape + using TileShape = TileShape_; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); 
+ + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreads = size(GmemTiledCopyA{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, 
"SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using MmaSmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + append(LoadShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + 
using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + return { + args.ptr_A, + args.dA, + args.ptr_B, + args.dB, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + 
bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors + Tensor mA_mkl = make_tensor(make_gmem_ptr(params.ptr_A), make_shape(M,K,L), params.dA); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.ptr_B), make_shape(N,K,L), params.dB); //(n,k,l) + // Partition for cpasync + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + Tensor cA_mk = make_identity_tensor(make_shape(M,K)); + Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgA_mk = local_tile(cA_mk, TileShape{}, 
make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), LoadSmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), LoadSmemLayoutB{}); + + GmemTiledCopyA gmem_to_smem_a_tiled_copy; + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = threadIdx.x % NumLoadThreads; + auto thr_copy_a = gmem_to_smem_a_tiled_copy.get_slice(thread_idx); + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // gmem + cgA_mk, cgB_nk, // crd + sA, sB, // smem + problem_shape_MNKL, + gmem_to_smem_a_tiled_copy, gmem_to_smem_b_tiled_copy, + thr_copy_a, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), MmaSmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class CTensorA, class CTensorB, + class STensorA, class STensorB, + class ProblemShape_MNKL, + class TiledCopyA, class TiledCopyB, + class ThreadCopyA, class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + // Unpack from load_inputs + GTensorA tAgA_mkl = get<0>(load_inputs); + GTensorB tBgB_nkl = get<1>(load_inputs); + CTensorA cgA_mk = get<2>(load_inputs); + CTensorB cgB_nk = get<3>(load_inputs); + STensorA sA = get<4>(load_inputs); + STensorB sB = get<5>(load_inputs); + ProblemShape_MNKL problem_shape_MNKL = get<6>(load_inputs); + TiledCopyA gmem_to_smem_a_tiled_copy = get<7>(load_inputs); + TiledCopyB gmem_to_smem_b_tiled_copy = get<8>(load_inputs); + ThreadCopyA thr_copy_a = get<9>(load_inputs); + ThreadCopyB thr_copy_b = get<10>(load_inputs); + auto [M,N,K,L] = problem_shape_MNKL; + + // Slice out the work coord from partitioned tensors + Tensor gA_in = tAgA_mkl(_, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + // Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgA_mk_in = cgA_mk(_, _, get<0>(cta_coord_mnkl), _); + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gA_in); + + // Shift tensor so residue_k is at origin (Can't read any k_coord < 
residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + Tensor gA = domain_offset(make_coord(0, k_residue, 0), gA_in); + Tensor gB = domain_offset(make_coord(0, k_residue, 0), gB_in); + + Tensor cA = domain_offset(make_coord(0, k_residue, 0), cgA_mk_in); + Tensor cB = domain_offset(make_coord(0, k_residue, 0), cgB_nk_in); + + auto tAgA = thr_copy_a.partition_S(gA); + auto tAsA = thr_copy_a.partition_D(sA); + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + Tensor tAcA = thr_copy_a.partition_S(cA); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tAgA and tBgB + Tensor tAcAk = tAcA(_,_,_,*k_tile_iter); + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = elem_less(get<0>(tAcAk(0,m,0)), M); // blk_m coord < M + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // 0-th stage with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0 && k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if ( 
int(get<1>(tAcAk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_a_tiled_copy, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,write_stage)); + } + else { + clear(tAsA(_,_,k,write_stage)); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + auto mainloop_pipe_producer_state_curr = mainloop_pipe_producer_state; + ++mainloop_pipe_producer_state; + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state_curr, barrier_token); + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state_curr.index(); + + copy_if(gmem_to_smem_a_tiled_copy, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state_curr, cutlass::arch::cpasync_barrier_arrive); + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages 
to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB + > + CUTLASS_DEVICE auto + mma(MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_consumer_state, + cute::tuple, cute::Tensor> const& accumulators_pair, + cute::tuple const& mma_inputs, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state); + + int read_stage = mainloop_pipe_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); + --k_tile_count; + ++mainloop_pipe_consumer_state; + } + + return mainloop_pipe_consumer_state; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c31ec335a5152032fca9a43a4d96613de260d1f3 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -0,0 +1,758 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/gemm/collective/collective_mma_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class 
GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + static_assert(size(typename TiledMma::AtomThrID{}) == 1); + + using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster + using TileShape = TileShape_; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + + using 
ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipelineTMA = cutlass::PipelineTmaUmmaAsync; + using MainloopPipelineTMAState = typename MainloopPipelineTMA::PipelineState; + + using MainloopPipelineCpAsync = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineCpAsyncState = typename MainloopPipelineCpAsync::PipelineState; + + // static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreadsCpAsync = size(GmemTiledCopyB{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced 
instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using 
RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorageTMA = typename MainloopPipelineTMA::SharedStorage; + using PipelineStorageCpAsync = typename MainloopPipelineCpAsync::SharedStorage; + + struct PipelineStorage : cute::aligned_struct<16, _0> { + alignas(16) PipelineStorageTMA tma; + alignas(16) PipelineStorageCpAsync cpasync; + } pipelines; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) + : runtime_data_type_a_(params.runtime_data_type_a) + , 
runtime_data_type_b_(params.runtime_data_type_b) { + + observed_tma_load_a_ = ¶ms.tma_load_a; + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + + auto cluster_layout_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + return { + tma_load_a, + args.ptr_B, + args.dB, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if (!implementable) { + 
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init_tma( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + ThrMMA cta_mma = TiledMma{}.get_slice(0); + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + return cute::make_tuple( + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tAsA // for input tensor values + ); + } + + template + CUTLASS_DEVICE auto + load_init_cpasync( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TileScheduler const& scheduler, + typename TileScheduler::WorkTileInfo const& work_tile_info) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.ptr_B), 
make_shape(N,K,L), params.dB); //(n,k,l) + // Partition for cpasync + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), LoadSmemLayoutB{}); + + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = threadIdx.x % NumLoadThreadsCpAsync; + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gB_nkl, cgB_nk, sB, + gmem_to_smem_b_tiled_copy, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] TmemStorage tmem_storage, + // [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class KTileCount, + class GTensorPartitionedA, + class STensorA, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_tma( + MainloopPipelineTMA mainloop_pipeline, + MainloopPipelineTMAState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + // Unpack from load_inputs + KTileCount k_tiles = get<0>(load_inputs); + GTensorPartitionedA tAgA_mkl = get<1>(load_inputs); + STensorA tAsA = get<2>(load_inputs); + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipelineTMA::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + template < + // class GTensorB, + // 
class CTensorB, + // class STensorB, + // class ProblemShape_MNKL, + // class TiledCopyB, + // class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator, + class ProblemShape_MNKL, + class... TParams + > + CUTLASS_DEVICE auto + load_cpasync( + Params const& params, + MainloopPipelineCpAsync mainloop_pipeline, + MainloopPipelineCpAsyncState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + ProblemShape_MNKL effective_shape + ) { + + // Unpack from load_inputs + // GTensorB tBgB_nkl = get<0>(load_inputs); + // CTensorB cgB_nk = get<1>(load_inputs); + // STensorB sB = get<2>(load_inputs); + // ProblemShape_MNKL problem_shape_MNKL = get<3>(load_inputs); + // TiledCopyB gmem_to_smem_b_tiled_copy = get<4>(load_inputs); + // ThreadCopyB thr_copy_b = get<5>(load_inputs); + + auto [ + tBgB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b] = load_inputs; + + auto [M,N,K,L] = effective_shape; + + // Slice out the work coord from partitioned tensors + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + // Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative + + Tensor gB = gB_in; + Tensor cB = cgB_nk_in; + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for n + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + Tensor tBcB_nk = thr_copy_b.partition_S(cgB_nk_in); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tBgB + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set 
predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // we will process the last tile after the mainloop + if (k_residue != 0) { + --k_tile_count; + } + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + --k_tile_count; + ++k_tile_iter; + ++mainloop_pipe_producer_state; + } + + // last tile with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgB(_,_,k,*k_tile_iter), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) { + // Issue the 
epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + CUTLASS_DEVICE void + load_tail_cpasync(MainloopPipelineCpAsync mainloop_pipeline, MainloopPipelineCpAsyncState mainloop_pipe_producer_state) { + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state); + + int read_stage_tma = 
mainloop_pipe_tma_consumer_state.index(); + int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index(); + + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage_tma), tCrB(_,_,k_block,read_stage_cpasync), accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + --k_tile_count; + ++mainloop_pipe_tma_consumer_state; + ++mainloop_pipe_cpasync_consumer_state; + } + + return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fe5ee3cd31c20f2e4f504777a33d2a25fb99a1cd --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp @@ -0,0 +1,726 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + 
TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + 
using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + using TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage 
pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + template < + class KTileCount, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB + > + struct LoadParams { + // for scheduler + KTileCount k_tiles; + // for input tensor values + GTensorPartitionedA tAgA_mkl; + GTensorPartitionedB tBgB_nkl; + STensorA tAsA; + STensorB tBsB; + // the TMA multicast masks + uint16_t mcast_mask_a; + uint16_t mcast_mask_b; + + CUTLASS_DEVICE + LoadParams ( + KTileCount k_tiles_, + GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_, + STensorA tAsA_, STensorB tBsB_, + uint16_t mcast_mask_a_, uint16_t mcast_mask_b_) + : k_tiles(k_tiles_) + , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_) + , tAsA(tAsA_), tBsB(tBsB_) + , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) {} + }; + + template < + class TiledMma, + class FragmentA, class FragmentB + > + struct MmaParams { + TiledMma tiled_mma; + FragmentA tCrA; + FragmentB tCrB; + + CUTLASS_DEVICE + MmaParams ( + TiledMma tiled_mma_, + FragmentA tCrA_, FragmentB tCrB_) + : tiled_mma(tiled_mma_) + , tCrA(tCrA_), tCrB(tCrB_) {} + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB 
runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? 
¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + 
GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + 
slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename 
TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + return LoadParams{ + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b}; // multicast masks + } + + /// Set up the data needed by this collective for mma compute. 
+ template + CUTLASS_DEVICE auto + mma_init( + [[maybe_unused]] TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + + return MmaParams{ + tiled_mma, + tCrA, tCrB}; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class LoadParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + LoadParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_k_tiles, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = 
mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class MmaParams, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + MmaParams const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + 
) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + +protected: + + 
typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp new file mode 100644 index 0000000000000000000000000000000000000000..047d9b98ab2c0a638304b789caac96f800992a82 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp @@ -0,0 +1,1239 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StridePairA_, + class ElementB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StridePairA_, + ElementB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using LayoutSFA = cute::remove_cvref_t(StridePairA_{}))>; + using ElementSFA = typename TiledMma::ValTypeC; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = cute::remove_cvref_t(StridePairB_{}))>; + using LayoutSFB = cute::remove_cvref_t(StridePairB_{}))>; + using ElementSFB = typename 
TiledMma::ValTypeC; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + static constexpr int ScaleGranularityM = size<0,0>(LayoutSFA{}); + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static_assert(size<0>(TileShape{}) % ScaleGranularityM == 0 and ScaleGranularityM <= size<0>(TileShape{}), "Scale Granularity M must divide Tile Shape"); + + static constexpr int ScaleGranularityN = size<0,0>(LayoutSFB{}); + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + static_assert(size<1>(TileShape{}) % ScaleGranularityN == 0 and ScaleGranularityN <= size<1>(TileShape{}), "Scale Granularity N must divide Tile Shape"); + + static_assert(size<1, 0>(LayoutSFA{}) == size<1, 0>(LayoutSFB{}), "Vector size K must be equal for SFA and SFB"); + + static constexpr int ScaleGranularityK = size<1, 0>(LayoutSFA{}); + static constexpr int ScaleKsPerTile = size<2>(TileShape{}) / ScaleGranularityK; + static_assert(size<2>(TileShape{}) % ScaleGranularityK == 0 and ScaleGranularityK <= size<2>(TileShape{}), "Scale Granularity K must divide Tile Shape"); + static_assert(ScaleGranularityK % size<2>(typename TiledMma::AtomShape_MNK{}) == 0, "Scale Granularity K must be divisible by MMA_K"); + + static constexpr int K_BLOCK_MMAS_PER_SCALE_K = ScaleGranularityK / size<2>(typename TiledMma::AtomShape_MNK{}); + + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig(LayoutSFA{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K, + size<0,1>(LayoutSFB{}.stride()) == 1 ? 
UMMA::Major::MN : UMMA::Major::K>; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + static_assert(size<0>(CtaShape_MNK{}) >= ScaleGranularityM, "Scale Granularity must be smaller than or equal to the tile shape"); + static_assert(size<1>(CtaShape_MNK{}) >= ScaleGranularityN, "Scale Granularity must be smaller than or equal to the tile shape"); + static_assert(size<2>(CtaShape_MNK{}) >= ScaleGranularityK, "Scale Granularity must be smaller than or equal to the tile shape"); + + using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(CtaShape_MNK{})); + using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(CtaShape_MNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = cute::remove_cvref_t(GmemTiledCopyPairA_{}))>; + using GmemTiledCopySFA = cute::remove_cvref_t(GmemTiledCopyPairA_{}))>; + using GmemTiledCopyB = cute::remove_cvref_t(GmemTiledCopyPairB_{}))>; + using GmemTiledCopySFB = cute::remove_cvref_t(GmemTiledCopyPairB_{}))>; + using 
SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopABPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopABPipelineState = typename MainloopABPipeline::PipelineState; + + using MainloopSFPipeline = cutlass::PipelineAsync; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync< + AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + static constexpr int CopyAlignmentSFA = GmemTiledCopySFA::AtomNumVal::value * sizeof(typename GmemTiledCopySFA::ValType) / sizeof(ElementAccumulator); + static constexpr int CopyAlignmentSFB = GmemTiledCopySFB::AtomNumVal::value * sizeof(typename GmemTiledCopySFB::ValType) / sizeof(ElementAccumulator); + + static constexpr int AlignmentSFA = CopyAlignmentSFA * (GmemTiledCopySFA::AtomNumVal::value > 1 ? + (size<0,1>(LayoutSFA{}.stride()) == 1 ? ScaleGranularityM : ScaleGranularityK) : 1); + static constexpr int AlignmentSFB = CopyAlignmentSFB * (GmemTiledCopySFB::AtomNumVal::value > 1 ? + (size<0,1>(LayoutSFB{}.stride()) == 1 ? 
ScaleGranularityN : ScaleGranularityK) : 1); + + + // Two arrivals per thread in the warp (1 arrival and 1 arrival through cp.async.mbarrier) + static constexpr int NumMainloopSFProducerThreadEvents = 64; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + using TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + using SmemLayoutScaleA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutScaleB = decltype(make_layout( + 
append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineABStorage = typename MainloopABPipeline::SharedStorage; + using PipelineSFStorage = typename MainloopSFPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + struct PipelineStorage { + alignas(16) PipelineABStorage pipeline_ab; + alignas(16) PipelineSFStorage pipeline_sf; + alignas(16) AccumulatorPipelineStorage pipeline_accum; + }; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + template< + class KTileCount, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB + > + struct LoadABParams { + // for scheduler + KTileCount k_tiles; + // for input tensor values + GTensorPartitionedA tAgA_mkl; + GTensorPartitionedB tBgB_nkl; + STensorA tAsA; + STensorB tBsB; + + // the TMA multicast masks + uint16_t mcast_mask_a; + uint16_t mcast_mask_b; + + CUTLASS_DEVICE + LoadABParams ( + KTileCount k_tiles_, + GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_, + STensorA tAsA_, STensorB tBsB_, + uint16_t 
mcast_mask_a_, uint16_t mcast_mask_b_) + : k_tiles(k_tiles_) + , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_) + , tAsA(tAsA_), tBsB(tBsB_) + , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) {} + }; + + template< + class KTileCount, + class GTensorScaleA, class GTensorScaleB, + class IdentTensorScaleA, class IdentTensorScaleB, + class STensorScaleA, class STensorScaleB + > + struct LoadSFParams { + // for scheduler + KTileCount k_tiles; + + GTensorScaleA gSFA_mkl; + GTensorScaleB gSFB_nkl; + IdentTensorScaleA identSFA_mkl; + IdentTensorScaleB identSFB_nkl; + STensorScaleA sSFA; + STensorScaleB sSFB; + + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + + CUTLASS_DEVICE + LoadSFParams ( + KTileCount k_tiles_, + GTensorScaleA gSFA_mkl_, GTensorScaleB gSFB_nkl_, + IdentTensorScaleA identSFA_mkl_, IdentTensorScaleB identSFB_nkl_, + STensorScaleA sSFA_, STensorScaleB sSFB_, + LayoutSFA layout_SFA_, LayoutSFB layout_SFB_) + : k_tiles(k_tiles_) + , gSFA_mkl(gSFA_mkl_), gSFB_nkl(gSFB_nkl_) + , identSFA_mkl(identSFA_mkl_), identSFB_nkl(identSFB_nkl_) + , sSFA(sSFA_), sSFB(sSFB_) + , layout_SFA(layout_SFA_), layout_SFB(layout_SFB_) {} + }; + + template + struct MmaParams { + TiledMma tiled_mma; + FragmentA tCrA; + FragmentB tCrB; + + CUTLASS_DEVICE + MmaParams ( + TiledMma tiled_mma_, + FragmentA tCrA_, FragmentB tCrB_) + : tiled_mma(tiled_mma_) + , tCrA(tCrA_), tCrB(tCrB_) {} + }; + + template< + class STensorScaleA, class STensorScaleB + > + struct AccumTransformParams { + // for scheduler + + STensorScaleA sSFA; + STensorScaleB sSFB; + + CUTLASS_DEVICE + AccumTransformParams ( + STensorScaleA sSFA_, STensorScaleB sSFB_) + : sSFA(sSFA_), sSFB(sSFB_) {} + }; + + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementAccumulator const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementAccumulator const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + 
RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + + ElementAccumulator const* ptr_SFA; + LayoutSFA layout_SFA; + ElementAccumulator const* ptr_SFB; + LayoutSFB layout_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? 
¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + 
GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + args.ptr_SFA, + args.layout_SFA, + args.ptr_SFB, + args.layout_SFB + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool implementable_sf = cutlass::detail::check_alignment(args.layout_SFA); + implementable_sf = implementable_sf && cutlass::detail::check_alignment(args.layout_SFB); + + if (!implementable_sf) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for Scale Factors.\n"); + } + + return implementable && implementable_sf; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + 
cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. 
+ /// Return load params containing + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_ab_init( + ProblemShape_MNKL const& problem_shape_MNKL, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = 
tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + LoadABParams load_params { + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + }; + return load_params; + } + + /// Set up the data needed by this collective for load. + /// Return load params containing + /// tSFAgSFA_mkl - partitioned gmem tensor for SFA + /// tSFBgSFB_nkl - partitioned gmem tensor for SFB + /// tSFAIdentSFA_mkl - partitioned identity tensor for SFA in gmem + /// tSFBIdentSFB_nkl - partitioned identity tensor for SFB in gmem + /// tSFAsSFA - partitioned smem tensor for SFA + /// tSFBsSFB - partitioned smem tensor for SFB + /// layout_SFA - layout of SFA in gmem + /// layout_SFB - layout of SFB in gmem + template + CUTLASS_DEVICE auto + load_sf_init( + ProblemShape_MNKL const& problem_shape_MNKL, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA), mainloop_params.layout_SFA); // (m,k,l) + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB), mainloop_params.layout_SFB); // (n,k,l) + + Tensor SFA_mkl_ident = make_identity_tensor(shape(mainloop_params.layout_SFA)); + + Tensor SFB_nkl_ident = 
make_identity_tensor(shape(mainloop_params.layout_SFB)); + + // Tile the tensors and defer the slice + Tensor gSFA_mkl = local_tile(mSFA_mkl, CtaShape_MNK{}, + make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, CtaShape_MNK{}, + make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + Tensor identSFA_mkl = local_tile(SFA_mkl_ident, CtaShape_MNK{}, + make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor identSFB_nkl = local_tile(SFB_nkl_ident, CtaShape_MNK{}, + make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + static_assert(rank(decltype(gSFA_mkl){}) == 5); + static_assert(rank(decltype(gSFB_nkl){}) == 5); + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), + SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), + SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) + + LoadSFParams load_params { + size<3>(gSFA_mkl), + gSFA_mkl, gSFB_nkl, // for input scale tensor values + identSFA_mkl, identSFB_nkl, // for predicating scale tensor copies + sSFA, sSFB, // for scale tensor values + mainloop_params.layout_SFA, // for predicating scale tensor copies + mainloop_params.layout_SFB // for predicating scale tensor copies + }; + return load_params; + } + + + /// Set up the data needed by this collective for mma compute. 
+ template + CUTLASS_DEVICE auto + mma_init( + [[maybe_unused]] TmemStorage tmem_tensors, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA_ = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB_ = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(rank(tCrA_) == _4{}); + + auto mma_tile_shape_A = make_shape(get<0>(shape(tCrA_.layout())), + get<1>(shape(tCrA_.layout())), + Int{}, + _1{}); + + auto mma_tile_shape_B = make_shape(get<0>(shape(tCrB_.layout())), + get<1>(shape(tCrB_.layout())), + Int{}, + _1{}); + + Tensor tCrA = flat_divide(tCrA_, + mma_tile_shape_A)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_M,MMA_K_PER_SCALE,MMA_K_REST,PIPE) + + Tensor tCrB = flat_divide(tCrB_, + mma_tile_shape_B)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_N,MMA_K_PER_SCALE,MMA_K_REST,PIPE) + + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + MmaParams mma_params { + tiled_mma, + tCrA, tCrB + }; + return mma_params; + } + + /// Set up the data needed by this collective for transform. 
+ template + CUTLASS_DEVICE auto + accum_init( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.begin()), + SmemLayoutScaleA{}); // (ScaleMsPerTile,ScakeKsPerTile,P) + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.begin()), + SmemLayoutScaleB{}); // (ScaleNsPerTile,ScaleKsPerTile,P) + + + AccumTransformParams transform_params { + sSFA, sSFB // for input tensor values + }; + return transform_params; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class LoadABParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_ab( + MainloopABPipeline mainloop_pipeline, + MainloopABPipelineState mainloop_pipe_producer_state, + LoadABParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_k_tiles, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopABPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + 
auto curr_mainloop_pipe_producer_state = mainloop_pipe_producer_state; + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_ab_tail( + MainloopABPipeline mainloop_pipeline, + MainloopABPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped transform + /// Load producer Perspective + template < + class LoadSFParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_sf( + MainloopSFPipeline mainloop_sf_pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state, + LoadSFParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_k_tiles, + gSFA_mkl, gSFB_nkl, + identSFA_mkl, identSFB_nkl, + sSFA, sSFB, + layout_SFA, layout_SFB] = load_inputs; + + // slice out the work coord from partitioned tensors + GmemTiledCopySFA scale_copy_a{}; + GmemTiledCopySFB scale_copy_b{}; + + Tensor gSFA_k_compact = filter_zeros( + gSFA_mkl(_, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl))); // (BLK_M_CPT, BLK_K_CPT, k_cpt) + Tensor gSFB_k_compact = 
filter_zeros( + gSFB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl))); // (BLK_N_CPT, BLK_K_CPT, k_cpt) + + Tensor identSFA_k_compact = filter_zeros( + identSFA_mkl(_, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)), + gSFA_k_compact.stride()); // (BLK_M_CPT, BLK_K_CPT, k_cpt) + Tensor identSFB_k_compact = filter_zeros( + identSFB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)), + gSFB_k_compact.stride()); // (BLK_N_CPT, BLK_K_CPT, k_cpt) + + Tensor sSFA_compact = filter_zeros(sSFA); // (BLK_M_CPT, BLK_K_CPT, P) + Tensor sSFB_compact = filter_zeros(sSFB); // (BLK_N_CPT, BLK_K_CPT, P) + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x % size(scale_copy_a)); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x % size(scale_copy_b)); + + Tensor tSFAgSFA_k_compact = thr_scale_copy_a.partition_S(gSFA_k_compact); // (CPY, BLK_M, BLK_K, k) + Tensor tSFAIdentSFA_k_compact = thr_scale_copy_a.partition_S(identSFA_k_compact); // (CPY, BLK_M, BLK_K, k) + + Tensor tSFAsSFA_compact = thr_scale_copy_a.partition_D(sSFA_compact); + + Tensor tSFBgSFB_k_compact = thr_scale_copy_b.partition_S(gSFB_k_compact); // (CPY, BLK_N, BLK_K, k) + Tensor tSFBIdentSFB_k_compact = thr_scale_copy_b.partition_S(identSFB_k_compact); // (CPY, BLK_N, BLK_K, k) + Tensor tSFBsSFB_compact = thr_scale_copy_b.partition_D(sSFB_compact); + + Tensor thr_tile_pSFA = make_fragment_like(tSFAgSFA_k_compact(_0{},_,_,_0{})); + Tensor thr_tile_pSFB = make_fragment_like(tSFBgSFB_k_compact(_0{},_,_,_0{})); + + // Issue the loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK pipe_producer_state for _writing_ + mainloop_sf_pipeline.producer_acquire(mainloop_sf_pipe_producer_state); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(thr_tile_pSFA); ++i) { + Tensor tSFAIdentSFA_compact = tSFAIdentSFA_k_compact(_0{},_,_,*k_tile_iter); + thr_tile_pSFA(i) = elem_less(tSFAIdentSFA_compact(i), + shape(filter_zeros(layout_SFA))) && threadIdx.x % 32 < 
size(scale_copy_a); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(thr_tile_pSFB); ++i) { + Tensor tSFBIdentSFB_compact = tSFBIdentSFB_k_compact(_0{},_,_,*k_tile_iter); + thr_tile_pSFB(i) = elem_less(tSFBIdentSFB_compact(i), + shape(filter_zeros(layout_SFB))) && threadIdx.x % 32 < size(scale_copy_b); + } + + copy_if(scale_copy_a, thr_tile_pSFA, tSFAgSFA_k_compact(_,_,_,*k_tile_iter), + tSFAsSFA_compact(_,_,_,mainloop_sf_pipe_producer_state.index())); + copy_if(scale_copy_b, thr_tile_pSFB, tSFBgSFB_k_compact(_,_,_,*k_tile_iter), + tSFBsSFB_compact(_,_,_,mainloop_sf_pipe_producer_state.index())); + mainloop_sf_pipeline.producer_commit(mainloop_sf_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); + + __syncwarp(); + + ++mainloop_sf_pipe_producer_state; + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_sf_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_sf_tail( + MainloopSFPipeline mainloop_sf_pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_sf_pipeline.producer_tail(mainloop_sf_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class TmemStorage, + class MmaParams, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma( + cute::tuple pipelines, + cute::tuple pipeline_states, + TmemStorage tmem_storage, + MmaParams const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count) { + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline, + accumulator_pipeline] = pipelines; + + auto 
[mainloop_pipe_consumer_state, + accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + CUTLASS_PRAGMA_UNROLL + for (int scale_k_iter = 0; scale_k_iter < size<3>(tCrA); ++scale_k_iter) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + auto acc = get<0>(slice_accumulator(tmem_storage, accumulator_pipe_producer_state.index())); + static_assert(is_tmem>::value, "Accumulator must be tmem resident."); + static_assert(rank(remove_cvref_t{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + // for each set of scale_k_blocks we zero the accumulator + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, + tCrA(_,_,k_block,scale_k_iter,read_stage), + tCrB(_,_,k_block,scale_k_iter,read_stage), + acc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + 
accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + + } + + return make_tuple(mainloop_pipe_consumer_state, accumulator_pipe_producer_state); + } + + /// Transform + template < + class AccumTransformParams, + class TmemStorage, + class CtaTileCoord, + class CopyOpT2R, + class EpilogueTile + > + CUTLASS_DEVICE auto + accum( + cute::tuple pipelines, + cute::tuple consumer_states, + TmemStorage tmem_storage, + AccumTransformParams const& transform_inputs, + CtaTileCoord cta_tile_coord, + CopyOpT2R, + EpilogueTile, + int k_tile_count) { + + static_assert(size<0>(EpilogueTile{}) <= size<0>(CtaShape_MNK{}), "Restrict epilogue tile to be smaller than or equal to CTA Tile"); + static_assert(size<1>(EpilogueTile{}) <= size<1>(CtaShape_MNK{}), "Restrict epilogue tile to be smaller than or equal to CTA Tile"); + + + // + // PIPELINED Transform + // + + Tensor acc = get<0>(slice_accumulator(tmem_storage, _0{})); + + Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); + + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Append N with a stride of 0 to SFA + Tensor sSFA_ = transform_inputs.sSFA; + Tensor sSFA = make_tensor(sSFA_.data(), make_layout( + make_shape(get<0>(sSFA_.shape()), get<1>(CtaShape_MNK{}), get<1>(sSFA_.shape()), get<2>(sSFA_.shape())), + make_stride(get<0>(sSFA_.stride()), _0{}, get<1>(sSFA_.stride()), get<2>(sSFA_.stride())) + )); + + CUTE_STATIC_ASSERT_V(size<0>(sSFA) == size<0>(tAcc)); + CUTE_STATIC_ASSERT_V(size<1>(sSFA) == size<1>(tAcc)); + + Tensor sSFA_epi = flat_divide(sSFA, EpilogueTile{}); + + // Append M with a stride of 0 to SFB + Tensor sSFB_ = transform_inputs.sSFB; + Tensor sSFB = make_tensor(sSFB_.data(), make_layout( + make_shape(get<0>(CtaShape_MNK{}), get<0>(sSFB_.shape()), get<1>(sSFB_.shape()), get<2>(sSFB_.shape())), + make_stride(_0{}, get<0>(sSFB_.stride()), 
get<1>(sSFB_.stride()), get<2>(sSFB_.stride())) + )); + + CUTE_STATIC_ASSERT_V(size<0>(sSFB) == size<0>(tAcc)); + CUTE_STATIC_ASSERT_V(size<1>(sSFB) == size<1>(tAcc)); + + Tensor sSFB_epi = flat_divide(sSFB, EpilogueTile{}); + + TiledCopy tiled_t2r_epi = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + + int thread_idx = threadIdx.x % size(tiled_t2r_epi); + + ThrCopy thread_t2r_epi = tiled_t2r_epi.get_slice(thread_idx); + + Tensor acc_ident_epi = make_identity_tensor(shape(tAcc_epi)); + + Tensor tTR_rAcc_epi = thread_t2r_epi.partition_D(acc_ident_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + Tensor tTR_sSFA_epi = thread_t2r_epi.partition_D(sSFA_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + Tensor tTR_sSFB_epi = thread_t2r_epi.partition_D(sSFB_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + static_assert(rank(decltype(tTR_sSFA_epi){}) == 7); + + Tensor tTR_FullAcc = make_tensor(shape(tTR_rAcc_epi)); + Tensor tTR_PartAcc = make_tensor(shape(tTR_rAcc_epi(_,_,_,_0{},_0{}))); + + Tensor tTR_rSFA_compact = make_fragment_like(filter_zeros(tTR_sSFA_epi(_,_,_,_,_,_,_0{}))); + Tensor tTR_rSFB_compact = make_fragment_like(filter_zeros(tTR_sSFB_epi(_,_,_,_,_,_,_0{}))); + + Layout tTR_rSFA_layout = make_layout(tTR_sSFA_epi(_,_,_,_,_,_,_0{}).shape(), tTR_rSFA_compact.stride()); + Layout tTR_rSFB_layout = make_layout(tTR_sSFB_epi(_,_,_,_,_,_,_0{}).shape(), tTR_rSFB_compact.stride()); + + // Zero our accumulator + clear(tTR_FullAcc); + + auto [accumulator_pipeline, mainloop_sf_pipeline] = pipelines; + auto [accumulator_pipe_state, mainloop_sf_pipe_state] = consumer_states; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_sf_pipeline.consumer_wait(mainloop_sf_pipe_state); + int read_idx = mainloop_sf_pipe_state.index(); + + copy(filter_zeros(tTR_sSFA_epi(_,_,_,_,_,_,read_idx)), tTR_rSFA_compact); + copy(filter_zeros(tTR_sSFB_epi(_,_,_,_,_,_,read_idx)), tTR_rSFB_compact); + + CUTE_STATIC_ASSERT_V(cosize(tTR_rSFA_layout) == size(tTR_rSFA_compact)); + 
CUTE_STATIC_ASSERT_V(cosize(tTR_rSFB_layout) == size(tTR_rSFB_compact)); + + Tensor tTR_rSFA = make_tensor(tTR_rSFA_compact.data(), tTR_rSFA_layout); + Tensor tTR_rSFB = make_tensor(tTR_rSFB_compact.data(), tTR_rSFB_layout); + + mainloop_sf_pipeline.consumer_release(mainloop_sf_pipe_state); + ++mainloop_sf_pipe_state; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < ScaleKsPerTile; ++k_block) { + + accumulator_pipeline.consumer_wait(accumulator_pipe_state); + + Tensor acc = get<0>(slice_accumulator(tmem_storage, accumulator_pipe_state.index())); + Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + Tensor tTR_tAcc = thread_t2r_epi.partition_S(tAcc_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(tAcc_epi); ++epi_m) { + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(tAcc_epi); ++epi_n) { + + auto scale_a = tTR_rSFA(_,_,_,epi_m,epi_n,k_block * ScaleGranularityK); + auto scale_b = tTR_rSFB(_,_,_,epi_m,epi_n,k_block * ScaleGranularityK); + + Tensor full_acc = tTR_FullAcc(_,_,_,epi_m,epi_n); + // Compute tmem load predication if necessary + copy(tiled_t2r_epi, tTR_tAcc(_,_,_,epi_m,epi_n), tTR_PartAcc); + cutlass::arch::fence_view_async_tmem_load(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(full_acc); ++i) { + ElementAccumulator scale = scale_a(i) * scale_b(i); + full_acc(i) += scale * tTR_PartAcc(i); + } + } + } + cutlass::arch::fence_view_async_tmem_load(); + accumulator_pipeline.consumer_release(accumulator_pipe_state); + // release acc + ++accumulator_pipe_state; + } + + --k_tile_count; + } + + return cute::make_tuple(tTR_FullAcc, tiled_t2r_epi, cute::make_tuple(accumulator_pipe_state, mainloop_sf_pipe_state)); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + + RuntimeDataTypeA 
runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp new file mode 100644 index 0000000000000000000000000000000000000000..54c3bd581a313d23d75c6b991e4373d78f670555 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp @@ -0,0 +1,1018 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" +#include "cutlass/detail/cluster.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +namespace detail { +template +struct CollectiveMmaEmulatedLayoutAtomType { + using InputLayoutAtom = InputLayoutAtom_; + using ComputeLayoutAtom = ComputeLayoutAtom_; +}; + +template +struct CollectiveMmaEmulatedCopyType { + using InputCopyAtom = InputCopyAtom_; + using ComputeCopyAtom = ComputeCopyAtom_; +}; +} 
// namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for FastF32 Kernels +template < + int Load2TransformPipelineStageCount_, + int Transform2MmaPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + int NumBandsToCompute_, + int ScalingFactor_, + int AccPromotionInterval_, + class AccumulatorCopyAtom_, + class ClusterShape, + class TileShape_, + class StrideA_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>, + TileShape_, + float, + StrideA_, + float, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ + // + // Type Aliases + // + + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using 
CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = float; + using PackedElementA = float2; + using StrideA = StrideA_; + using ElementAMma = typename TiledMma::ValTypeA; + using PackedElementAMma = uint32_t; + using ElementB = float; + using PackedElementB = float2; + using StrideB = StrideB_; + using ElementBMma = typename TiledMma::ValTypeB; + using PackedElementBMma = uint32_t; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(cute::is_same_v, "Input type A should be float"); + static_assert(cute::is_same_v, "Input type B should be float"); + static_assert(cute::is_same_v, "Compute type A should be cutlass::bfloat16_t"); + static_assert(cute::is_same_v, "Compute type A should be cutlass::bfloat16_t"); + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + DispatchPolicy::Transform2MmaPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + // Thread Counts + 
static constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; + + // Get the Algorithm parameters + constexpr static int NumComputeMtxs = 3; + constexpr static int NumBandsToCompute = DispatchPolicy::NumBandsToCompute; + constexpr static int ScalingFactor = DispatchPolicy::ScalingFactor; + constexpr static int AccPromotionInterval = DispatchPolicy::AccPromotionInterval; + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}) / DispatchPolicy::AccPromotionInterval; + constexpr static int NumBandsMax = 5; + static_assert(NumBandsToCompute <= NumBandsMax && NumBandsToCompute >= 3, "NumBandsToCompute should be less than maximum number of bands"); + + // Copy atom for Accumulator + using AccumulatorCopyAtom = typename DispatchPolicy::AccumulatorCopyAtom; + + static_assert((NumBandsToCompute == 5 || NumBandsToCompute == 4 || NumBandsToCompute == 3), + "9xBF16 with 5/4/3 Bands are supported"); + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, 
"SmemLayoutAtomCompute must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + append(append(CtaShapeA_MK{}, Int{}), Int{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutBCompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomBCompute{}, + append(append(CtaShapeB_NK{}, Int{}), Int{}))); + + static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2 && DispatchPolicy::Load2TransformPipelineStageCount >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert((cute::is_base_of::value || + cute::is_base_of::value ) && + cute::is_base_of::value, + "MMA atom must A operand from SMEM or TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyB - invalid TMA copy atom specified."); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename 
Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + struct TensorStorageUntransformed { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + }; + + struct TensorStorageTransformedAinSmem { + alignas(1024) cute::ArrayEngine> smem_ACompute; + alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + union TensorStorageTransformedAinTmem { + alignas(1024) cute::ArrayEngine smem_ACompute; // No smem_ACompute + alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + using TensorStorageTransformed = cute::conditional_t< + cute::is_base_of::value, + TensorStorageTransformedAinSmem, + TensorStorageTransformedAinTmem>; + + TensorStorageUntransformed input; + TensorStorageTransformed compute; + } tensors; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. 
Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? 
¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor tensor_a = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + 
SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + + /// Construct A 
Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + Load2TransformPipeline pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK mainloop_load2xform_pipeline_state for _writing_ + pipeline.producer_acquire(load2xform_pipeline_state, pipeline_flag); + int write_stage = load2xform_pipeline_state.index(); + + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop_pipe + ++load2xform_pipeline_state; + skip_wait = (k_tile_count <= 1); + pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), 
tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + ++k_tile_iter; + } + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + } + + /// Set up the data needed by this collective for load. + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_storage) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + 
group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b); // multicast masks + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class DstCopyA, class SrcTensorA, class DstTensorA, + class GTensorB, class SrcTensorB, class DstTensorB + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + + static_assert(cute::is_same_v, "ElementA and ElementB types should be the same."); + static_assert(cute::is_same_v, "ElementAMma and ElementBMma types should be the same."); + + cutlass::arch::NamedBarrier transform_bar(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAdA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, NumComputeMtxs, SmemStages (In SMEM or TMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, NumComputeMtxs, SmemStages (In SMEM) + auto [unused_tAgA, dst_copy_A, tAsA, tAdACompute, + unused_tBgB, tBsB, tBsBCompute] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArA_temp = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArACompute = make_tensor(tAsA(_,_,_,_,0).shape()); + + auto tBrB = 
make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrB_temp = make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrBCompute = make_tensor(tBsB(_,_,_,_,0).shape()); + + auto tArA_x2 = recast>(tArA); + auto tArA_temp_x2 = recast>(tArA_temp); + auto tArACompute_x2 = recast>(tArACompute); + + auto tBrB_x2 = recast>(tBrB); + auto tBrB_temp_x2 = recast>(tBrB_temp); + auto tBrBCompute_x2 = recast>(tBrBCompute); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Copy the input B matrix from SMEM + copy(AutoVectorizingCopy{}, tBsB(_,_,_,_,load2transform_consumer_index), tBrB); + // Copy the input A matrix from SMEM + copy(AutoVectorizingCopy{}, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tBrB_x2, tBrBCompute_x2, cutlass::NumericArrayConverter::convert); + copy(AutoVectorizingCopy{}, tBrBCompute, tBsBCompute(_,_,_,_,comp_mtx_index,transform2mma_producer_index)); + + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < 
NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tBrBCompute_x2, tBrB_temp_x2, cutlass::NumericArrayConverter::convert); + cute::transform(tBrB_x2, tBrB_temp_x2, tBrB_x2, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tBrB_x2, tBrB_x2, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + } + + // Loads from SMEM are done. Signal the mainloop load as early as possible + transform_bar.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tArA_x2, tArACompute_x2, cutlass::NumericArrayConverter::convert); + copy(dst_copy_A, tArACompute, tAdACompute(_,_,_,_,comp_mtx_index,transform2mma_producer_index)); + + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tArACompute_x2, tArA_temp_x2, cutlass::NumericArrayConverter::convert); + cute::transform(tArA_x2, tArA_temp_x2, tArA_x2, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tArA_x2, tArA_x2, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + } + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag = 
load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sB_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor sB = as_position_independent_swizzle_tensor(sB_orig); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] ( + auto tensor_input, + auto input_copy_atom, + auto tensor_compute, + auto make_fragment, + auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. 
+ Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); // ((128,16),m,k,PIPE) + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto reg2tmem_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0,0)); + auto thr_reg2tmem_tiled_copy = reg2tmem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2tmem_tiled_copy.partition_S(tensor_input2x); + auto partitioned_tensor_compute = thr_reg2tmem_tiled_copy.partition_D(fragment_compute); + return cute::make_tuple(reg2tmem_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto reg2smem_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _8,_1>>{}, + tensor_compute(_,_,_,0,0).layout()); + + auto thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(tensor_input); + auto partitioned_tensor_compute = thr_reg2smem_tiled_copy.partition_D(tensor_compute_ind_sw); + + return cute::make_tuple(AutoVectorizingCopy{}, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [dst_copy_A, tAsA, tAsACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + auto [dst_copy_B, tBsB, tBsBCompute] = + setup_copy_ops(sB, InputCopyAtomB{}, sBCompute, [&](auto &arg) {return TiledMma::make_fragment_B(arg);}, ComputeCopyAtomB{}); + + return cute::make_tuple(gA_mkl, dst_copy_A, tAsA, 
tAsACompute, + gB_nkl, tBsB, tBsBCompute); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + + // tCrA : (MMA), MMA_M, MMA_K, NumComputeMtxs, SmemStage (In SMEM or TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, NumComputeMtxs, SmemStages (In SMEM) + auto const [tCrA, tCrB] = input_operands; + + using ZeroScaler = cute::integral_constant; + using Scaler = cute::integral_constant; + + int remaining_accum_promotions = k_tile_count * StagesPerTile; + uint32_t mma2accum_skip_wait = (remaining_accum_promotions <= 0); + auto mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); k_block += DispatchPolicy::AccPromotionInterval, 
--remaining_accum_promotions) { + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state, mma2accum_flag); + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC = accumulators(_,_,_,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + + ++mma2accum_pipeline_producer_state; + mma2accum_skip_wait = (remaining_accum_promotions <= 1); + mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + auto tCrA0 = tCrA(_,_,_,0,transform2mma_pipeline_consumer_state_index); + auto tCrA1 = tCrA(_,_,_,1,transform2mma_pipeline_consumer_state_index); + auto tCrA2 = tCrA(_,_,_,2,transform2mma_pipeline_consumer_state_index); + + auto tCrB0 = tCrB(_,_,_,0,transform2mma_pipeline_consumer_state_index); + auto tCrB1 = tCrB(_,_,_,1,transform2mma_pipeline_consumer_state_index); + auto tCrB2 = tCrB(_,_,_,2,transform2mma_pipeline_consumer_state_index); + + // MMA instructions Emulation + auto accumulate = UMMA::ScaleOut::Zero; + // First set of GEMMs that we need to perform for each band are unrolled to set compile-time constant + // scaling parameter. Scaled GEMM operations are only needed for the first MMA operation of each band. 
+ + // Band 5 + if constexpr (NumBandsToCompute == 5) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[2] + accumulate = UMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[2]*B[2] + } + } + // Band 4 + if constexpr (NumBandsToCompute >= 4) { + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA1(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[1]*B[2] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[2]*B[1] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[1]*B[2] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[2]*B[1] + } + } + // Band 3 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[0] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[2] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[2]*B[0] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[2] + } + // Band 2 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), 
tCrA0(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[1]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[1]*B[0] + } + // Band 1 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[0] + } + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + } + + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + return cute::make_tuple(curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + auto get_tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + 
cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + }; + + Tensor tCrA = get_tCrA(); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + Tensor tCrB = tiled_mma.make_fragment_B(sBCompute); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { + // Obtain a single accumulator + Tensor tAcc = tensor<0>(accumulators(_,_,_,_0{})); + // Apply epilogue subtiling + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + // Create the TMEM copy for single EpilogueTile. + // Note that EpilogueTile = CtaTile for NoSmem epilogue + auto tiled_t2r = make_tmem_copy(tmem_cp_atom, tAcc_epi(_,_,_0{},_0{})); + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(tAcc_epi); + Tensor tTR_rAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rGlobAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) + Tensor tTR_rGlobAcc_float2 = recast>(tTR_rGlobAcc); // (T2R/2,T2R_M,T2R_N) + + // Apply epilogue subtiling to bulk accumulator + // We need to tile the whole bulk_tmem allocation with EpilogueTile. 
+ // The accumulation should be aware of the AccumulatorPipelineStages + Tensor tBulkAcc_epi = flat_divide(accumulators(make_coord(_,_),_0{},_0{}, _), EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,PIPE) + Tensor tTR_tBulkAcc = thread_t2r.partition_S(tBulkAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N,PIPE) + return cute::make_tuple(tiled_t2r, thread_t2r, tTR_tBulkAcc, tTR_rAcc, tTR_rGlobAcc); + } + + template + CUTLASS_DEVICE auto + accum(cute::tuple accum_inputs, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_consumer_state, + int k_tile_count) { + auto [tiled_t2r, thread_t2r, tTR_tBulkAcc, + tTR_rAcc, tTR_rGlobAcc] = accum_inputs; + + + Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) + Tensor tTR_rGlobAcc_float2 = recast>(tTR_rGlobAcc); // (T2R/2,T2R_M,T2R_N) + + // Clear the global accumulator + CUTE_UNROLL + for (int i = 0; i 0; --k_tile_count) { + // The stage is limited to a CTA tile + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block>{}); + + cutlass::arch::fence_view_async_tmem_load(); // Need a fence bw TMEM_LOAD and arrive + mma2accum_pipeline.consumer_release(mma2accum_pipeline_consumer_state); + + ++mma2accum_pipeline_consumer_state; + skip_wait = ((k_tile_count <= 1) && (k_block >= (StagesPerTile-1))); + mma2accum_flag = mma2accum_pipeline.consumer_try_wait(mma2accum_pipeline_consumer_state, skip_wait); + } + } + return cute::make_tuple(mma2accum_pipeline_consumer_state, tTR_rGlobAcc); + } + +protected: + + template + CUTLASS_DEVICE + constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors 
and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5adc2b817e81c0f7f05a9dd1816c7990280d02f4 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp @@ -0,0 +1,1296 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/detail/collective/mixed_input_utils.hpp" +#include "cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for Mixed Input Kernels +template < + int Load2TransformPipelineStageCount_, + int Transform2MmaPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape, + class TileShape_, + class ElementAOptionalTuple_, + class StridePairA_, + class ElementBOptionalTuple_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedMixedInput< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + ClusterShape>, + TileShape_, + ElementAOptionalTuple_, 
+ StridePairA_, + ElementBOptionalTuple_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ +public: + // + // Type Aliases + // + + using ConversionMode = cutlass::detail::ConversionMode; + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedMixedInput< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + ClusterShape>; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + using KernelSchedule = typename DispatchPolicy::Schedule; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + using ElementAOptionalTuple = ElementAOptionalTuple_; + using ElementBOptionalTuple = ElementBOptionalTuple_; + +private: + + template friend struct detail::MixedInputUtils; + using CollectiveType = CollectiveMma; + using Utils = detail::MixedInputUtils; + + using ElementScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple_>; + using ElementScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ElementZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ElementZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + +public: + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," + "[ElementZero]}. 
Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. + using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using LayoutScale = cute::remove_cvref_t(StridePairA_{}))>; + using InternalStrideA = cute::remove_pointer_t; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + static_assert((IsATransformed && cutlass::gemm::detail::is_k_major()) || + (!IsATransformed && cutlass::gemm::detail::is_k_major()), + "The transformed type must be K-major."); + + static_assert(( IsATransformed && (sizeof(ElementB) == 2)) || + (!IsATransformed && (sizeof(ElementA) == 2)) || + (cutlass::gemm::detail::is_k_major() && + cutlass::gemm::detail::is_k_major()), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = GmemTiledCopyA_; + + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = 
SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using SmemCopyAtomScale = Copy_Atom; + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemLayoutAtomACompute = cute::conditional_t; + using InternalSmemLayoutAtomBCompute = cute::conditional_t; + + using InternalInputCopyAtomA = cute::conditional_t; + using InternalInputCopyAtomB = cute::conditional_t; + using InternalComputeCopyAtomA = cute::conditional_t; + using InternalComputeCopyAtomB = cute::conditional_t; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealSwappedElementA = cute::conditional_t; + using RealSwappedElementB = cute::conditional_t; + using SwappedElementA = cute::conditional_t; + using SwappedElementB = cute::conditional_t; + using SwappedStrideA = cute::conditional_t; + using SwappedStrideB = cute::conditional_t; + using InternalSwappedStrideA = cute::conditional_t; + using InternalSwappedStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using InternalTransformA = cute::conditional_t; + using InternalTransformB = cute::conditional_t; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + using TmaElementScale = uint_bit_t >; // in case we have array. translating to uint to satisfy tma descriptor's specialization + + using ArchTag = typename DispatchPolicy::ArchTag; + static_assert(cute::is_same_v || cute::is_same_v || cute::is_same_v, + "Compute type A should be cutlass::bfloat16_t or cutlass::half_t or cutlass::float_e4m3_t"); + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Load2MmaPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using Load2MmaPipelineState = typename Load2MmaPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + DispatchPolicy::Transform2MmaPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + 
DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + + static constexpr int ScaleGranularityMN = size<0,0>(LayoutScale{}); + static constexpr int ScaleGranularityK = size<1,0>(LayoutScale{}); + using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig< + ScaleGranularityMN, + ScaleGranularityK>; + + using ScaleTileShape = cute::conditional_t(TileShape{}), size<2>(TileShape{}))), + decltype(make_shape(size<1>(TileShape{}), size<2>(TileShape{})))>; + + static constexpr int ScaleTileShape_MN = get<0>(ScaleTileShape{}); + + static constexpr int ScaleK = get<1>(ScaleTileShape{}) / ScaleGranularityK; + + using SmemLayoutAtomScale = decltype(ScaleConfig::smem_atom_layout_scale(ScaleTileShape{})); + + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Thread Counts + static constexpr uint32_t NumTransformationThreads = 128; + 
static constexpr uint32_t NumAccumThreads = 128; //Maintains compatibility with input_transform kernel + + // Get the Algorithm parameters + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutScale = decltype(make_layout( + append(shape(SmemLayoutAtomScale{}), Int{}), + append(stride(SmemLayoutAtomScale{}), size(filter_zeros(SmemLayoutAtomScale{}))) + )); + + static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2 && DispatchPolicy::Load2TransformPipelineStageCount >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert((cute::is_base_of::value || + cute::is_base_of::value ) && + cute::is_base_of::value, + "MMA atom must A operand from SMEM or TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + +private: + static constexpr ConversionMode + get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } + else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } + else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + +public: + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && + cutlass::detail::is_Array_v; + static constexpr size_t SmemAlignmentA = 
cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Load2MmaPipelineStorage = typename Load2MmaPipeline::SharedStorage; + alignas(16) Load2MmaPipelineStorage load2mma_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + static constexpr int scale_elements = Utils::elements_per_smem_scale(); + static constexpr int zero_elements = Utils::elements_per_smem_zero(); + struct TensorStorage : cute::aligned_struct<128, _0> { + + struct TensorStorageUntransformed { + alignas(512) cute::ArrayEngine> smem_A; + alignas(1024) cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + }; + + struct TensorStorageTransformedAinSmem { + // We require alignas(1024) here because the smem_ACompute may not be aligned to 1024 by default. + // We need 1024B alignment of smem_ACompute because we are using Swizzle<3,4,3> here. + // The Swizzle<3,4,3> aligns with 1024B. If we don't align the data, the compiler cannot deduce + // the base pointer of the data. + // This alignment allows us to perform the function swizzle(layout(i) * base_ptr). 
+ alignas(1024) cute::ArrayEngine> smem_ACompute; + }; + + union TensorStorageTransformedAinTmem { + cute::ArrayEngine smem_ACompute; // No smem_ACompute + }; + + using TensorStorageTransformed = cute::conditional_t< + cute::is_base_of::value, + TensorStorageTransformedAinSmem, + TensorStorageTransformedAinTmem>; + + TensorStorageUntransformed input; + TensorStorageTransformed compute; + } tensors; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes_A = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + Utils::compute_tma_transaction_bytes_extra_transform(); + static constexpr uint32_t TmaTransactionBytes_B = cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytes_A + TmaTransactionBytes_B; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementScale const* ptr_S{nullptr}; + LayoutScale layout_S{}; + ElementZero const* ptr_Z{nullptr}; + }; + + struct TMAScaleParams { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_Scale = decltype(make_tma_atom( + GmemTiledCopyScale{}, + make_tensor(static_cast(nullptr), LayoutScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + size<2>(ClusterLayout_VMNK{})) + ); + + TMA_Scale tma_load_scale; + TMA_Scale tma_load_zero; + + }; + + struct EmptyScaleParams {}; + + // Device side kernel params + struct Params : public cute::conditional_t { + + using 
ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + + uint32_t tma_transaction_bytes{TmaTransactionBytes}; + SwappedStrideA dA{}; + SwappedStrideB dB{}; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? 
¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor tensor_a = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + 
SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + uint32_t tma_transaction_bytes = TmaTransactionBytes; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return { + {}, + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + tma_transaction_bytes, + args.dA, args.dB }; + } + else if constexpr (ModeHasScales) { + ElementScale const* ptr_S = args.ptr_S; + + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), args.layout_S); + typename Params::TMA_Scale tma_load_scale = make_tma_atom( + GmemTiledCopyScale{}, + tensor_scale, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + size<2>(cluster_layout_vmnk) + ); + + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { + typename Params::TMAScaleParams scale_params{tma_load_scale, {}}; + return { + scale_params, + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + tma_transaction_bytes, + args.dA, args.dB }; + } + else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(args.ptr_Z), args.layout_S); + typename Params::TMA_Scale tma_load_zero = make_tma_atom( + GmemTiledCopyScale{}, + tensor_zero, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + size<2>(cluster_layout_vmnk)); + + typename Params::TMAScaleParams scale_params{tma_load_scale, tma_load_zero}; + return { + scale_params, + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + tma_transaction_bytes, + args.dA, args.dB }; + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + + 
template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_S = cutlass::detail::get_input_alignment_bits(); + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + bool check_aligned_A = cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + bool check_aligned_B = cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + bool check_aligned_S = true; + bool check_aligned_Z = true; + bool check_mode_args = true; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + check_mode_args = check_mode_args && (args.ptr_S == nullptr); + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits_S / cutlass::sizeof_bits::value; + check_aligned_S = cutlass::detail::check_alignment(args.layout_S); + check_mode_args = check_mode_args && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits_S / cutlass::sizeof_bits::value; + check_aligned_Z = cutlass::detail::check_alignment(args.layout_S); + check_mode_args = check_mode_args && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in 
can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!check_mode_args) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid arguments for the selected conversion mode.\n"); + } + if (!check_aligned_A) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor A meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_B) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor B meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_S) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor S (scale) meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_Z) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor Z (zeros) meet the minimum alignment requirements for TMA.\n"); + } + + return check_mode_args && check_aligned_A && check_aligned_B && check_aligned_S && check_aligned_Z; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert); + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + 
cute::prefetch_tma_descriptor(params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator, + class... Ts + > + CUTLASS_DEVICE auto + load_A( + Params const& params, + Load2TransformPipeline load2xform_pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, extra_input_partitions] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2xform_pipeline_flag = load2xform_pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + //Load2Mma and Load2Transform pipelines both have the same ProducerBarrierType + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + // 
LOCK mainloop_load2xform_pipeline_state for _writing_ + load2xform_pipeline.producer_acquire(load2xform_pipeline_state, load2xform_pipeline_flag); + + int tile_A_write_stage = load2xform_pipeline_state.index(); + + BarrierType* load2xform_tma_barrier = load2xform_pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop load2transform pipeline + ++load2xform_pipeline_state; + + skip_wait = (k_tile_count <= 1); + load2xform_pipeline_flag = load2xform_pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // TMA load for A k_tile + copy(observed_tma_load_a_->with(*load2xform_tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,tile_A_write_stage)); + + if constexpr (ModeHasScales) { + auto tSgS_mkl = get<0>(extra_input_partitions); + auto tSgS = tSgS_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + auto tSsS = get<1>(extra_input_partitions); + copy(params.tma_load_scale.with(*load2xform_tma_barrier, mcast_mask_a), tSgS(_,*k_tile_iter), tSsS(_,tile_A_write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ_mkl = get<2>(extra_input_partitions); + auto tZgZ = tZgZ_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + auto tZsZ = get<3>(extra_input_partitions); + copy(params.tma_load_zero.with(*load2xform_tma_barrier, mcast_mask_a), tZgZ(_,*k_tile_iter), tZsZ(_,tile_A_write_stage)); + } + } + else { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert); + else static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + } + + + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class 
STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator, + class... Ts + > + CUTLASS_DEVICE auto + load_B( + Params const& params, + Load2MmaPipeline load2mma_pipeline, + Load2MmaPipelineState load2mma_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, extra_input_partitions] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2mma_pipeline_flag = load2mma_pipeline.producer_try_acquire(load2mma_pipeline_state, skip_wait); + + //Load2Mma and Load2Transform pipelines both have the same ProducerBarrierType + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + // LOCK mainloop_load2mma_pipeline_state for _writing_ + load2mma_pipeline.producer_acquire(load2mma_pipeline_state, load2mma_pipeline_flag); + + int tile_B_write_stage = load2mma_pipeline_state.index(); + + BarrierType* load2mma_tma_barrier = load2mma_pipeline.producer_get_barrier(load2mma_pipeline_state); + + // Advance mainloop load2mma pipeline + ++load2mma_pipeline_state; + + skip_wait = (k_tile_count <= 1); + load2mma_pipeline_flag = load2mma_pipeline.producer_try_acquire(load2mma_pipeline_state, skip_wait); + + // TMA load for B k_tile + copy(observed_tma_load_b_->with(*load2mma_tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,tile_B_write_stage)); + + ++k_tile_iter; + } + + return cute::make_tuple(load2mma_pipeline_state, k_tile_iter); + + } + + /// Set up the data needed by this collective for load. 
+ /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_storage) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + if constexpr 
(KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple()); + } + else if constexpr (ModeHasScales) { + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor mS_mkl = params.tma_load_scale.get_tma_tensor(shape(LayoutScale{})); + Tensor gS_mkl = local_tile(mS_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); + + Tensor tCgS_mkl = cta_mma.partition_A(gS_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCsS = cta_mma.partition_A(sS); + + // Project the cta_layout for tma_scale along the n-modes + auto [tSgS_mkl, tSsS] = tma_partition(params.tma_load_scale, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsS), group_modes<0,3>(tCgS_mkl)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple(tSgS_mkl, tSsS)); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = params.tma_load_scale.get_tma_tensor(shape(LayoutScale{})); + Tensor gZ_mkl = local_tile(mS_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{}); + + Tensor tCgZ_mkl = cta_mma.partition_A(gZ_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor tCsZ = cta_mma.partition_A(sZ); + // Project the cta_layout for tma_scale along the n-modes + auto [tZgZ_mkl, tZsZ] = tma_partition(params.tma_load_zero, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + 
group_modes<0,3>(tCsZ), group_modes<0,3>(tCgZ_mkl)); + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple(tSgS_mkl, tSsS, tZgZ_mkl, tZsZ)); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class DstCopyA, class SrcTensorA, class DstTensorA, + class... Ts + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple> input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + + static_assert(cute::is_same_v, "ElementAMma and ElementBMma types should be the same."); + cutlass::arch::NamedBarrier transform_bar(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAsACompute : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM or TMEM) + auto [unused_tAgA, dst_copy_A, tAsA, tAsACompute, + partitioned_extra_info] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); //(Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest (Register) + auto tArACompute = make_tensor(tAsA(_,_,_,_,0).shape()); + constexpr int K_BLOCK_MAX = size<3>(tArA); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = 
transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); // read stage + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); //write stage + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + + // Copy the input A matrix from SMEM + copy(AutoVectorizingCopy{}, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + // Copy scale/zero vector from SMEM + Utils::copy_scale_zeros_for_transform(partitioned_extra_info, load2transform_consumer_index); + + // Loads from SMEM are done. Signal the mainloop load as early as possible + transform_bar.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Dequantize A with scale/zero in RF + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; k_block ++){ + Utils::dequantize_A_kblock_for_transform(tArA, tArACompute, partitioned_extra_info, k_block); + } + + // Dequantized A is stored into either Smem or Tmem + copy(dst_copy_A, tArACompute, tAsACompute(_,_,_,_,transform2mma_producer_index)); + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + 
++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); + Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] ( + auto tensor_input, + auto input_copy_atom, + auto tensor_compute, + auto make_fragment, + auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. 
+ Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + // If operand comes from TMEM, create the TMEM_STORE based copy + auto r2t_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0)); + auto thr_r2t_tiled_copy = r2t_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_r2t_tiled_copy.partition_S(tensor_input2x); //(TMEM_STORE, TMEM_STORE_M, TMEM_STORE_N) + auto partitioned_tensor_compute = thr_r2t_tiled_copy.partition_D(fragment_compute); //(TMEM_STORE, TMEM_STORE_M, TMEM_STORE_N) + + // Source copy is based on the source operand of TMEM_STORE copy. 
+ auto smem2reg_tiled_copy = make_tiled_copy_S(Copy_Atom{}, r2t_tiled_copy); + return cute::make_tuple(smem2reg_tiled_copy, r2t_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto r2s_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _8,_1>>{}, + tensor_compute(_,_,_,0).layout()); + + auto smem2reg_tiled_copy = make_tiled_copy_S(input_copy_atom, r2s_tiled_copy); + auto thr_r2s_tiled_copy = r2s_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_r2s_tiled_copy.partition_S(tensor_input); //(SMEM_STORE, SMEM_STORE_M, SMEM_STORE_N) + + auto partitioned_tensor_compute = thr_r2s_tiled_copy.partition_D(tensor_compute_ind_sw);//(SMEM_STORE, SMEM_STORE_M, SMEM_STORE_N) + + + return cute::make_tuple(smem2reg_tiled_copy, AutoVectorizingCopy{}, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [src_copy_A, dst_copy_A, tAsA, tAsACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + // Partition of thread -> shared and thread -> RF + auto fragment_compute = TiledMma::make_fragment_A(sACompute); + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto r2t_tiled_copy = make_tmem_copy(ComputeCopyAtomA{}, fragment_compute(_,_,_,0)); + auto src_copy_scale = make_tiled_copy_S(Copy_Atom{}, r2t_tiled_copy); + + auto partitioned_extra_info = Utils::partition_extra_transform_info(TiledMma{}, src_copy_scale, shared_storage); + + return cute::make_tuple(gA_mkl, dst_copy_A, tAsA, tAsACompute, + partitioned_extra_info); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + 
Load2MmaPipeline load2mma_pipeline, + Load2MmaPipelineState load2mma_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + auto curr_load2mma_pipeline_consumer_state = load2mma_pipeline_consumer_state; + auto next_load2mma_pipeline_consumer_state = load2mma_pipeline_consumer_state; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + auto load2mma_flag = load2mma_pipeline.consumer_try_wait(next_load2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + ++next_load2mma_pipeline_consumer_state; + + + // tCrA : (MMA), MMA_M, MMA_K, SmemStage (In SMEM or TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, SmemStages (In SMEM) + auto const [tCrA, tCrB] = input_operands; + + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state); + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC = accumulators(_,_,_,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + ++mma2accum_pipeline_producer_state; + + // + // PIPELINED MAIN LOOP + // + // Clear the accumulator + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + 
load2mma_pipeline.consumer_wait(curr_load2mma_pipeline_consumer_state, load2mma_flag); + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int load2mma_pipeline_consumer_state_index = curr_load2mma_pipeline_consumer_state.index(); //read_stage + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); //read_stage + + auto tCrA0 = tCrA(_,_,_,transform2mma_pipeline_consumer_state_index); + auto tCrB0 = tCrB(_,_,_,load2mma_pipeline_consumer_state_index); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); k_block ++) { + cute::gemm(tiled_mma, tCrA0(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[0] + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + load2mma_pipeline.consumer_release(curr_load2mma_pipeline_consumer_state); + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + load2mma_flag = load2mma_pipeline.consumer_try_wait(next_load2mma_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_load2mma_pipeline_consumer_state = next_load2mma_pipeline_consumer_state; + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + + ++next_load2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + + return cute::make_tuple(curr_load2mma_pipeline_consumer_state, curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + auto get_tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { + Tensor sACompute = 
make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + }; + + Tensor tCrA = get_tCrA(); + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor tCrB = tiled_mma.make_fragment_B(sB); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { + return accumulators; + } + +private: + template + CUTLASS_DEVICE + constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d2d8172fb808a95f38a542d92c5b300ee5cb3921 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp @@ -0,0 +1,951 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/collective/builders/sm1xx_sparse_config.inl" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class LayoutPairAE_, + 
class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedSparse< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + LayoutPairAE_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedSparse< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + static_assert(get<0,0>(MmaShapeA_MK{}) == 128 && + (get<2>(MmaShapeA_MK{}) == 2 || get<2>(MmaShapeA_MK{}) == 4), + "This kernel only support MmaShape=128 and 2/4 kphase."); + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaRaw = typename ElementAMma::raw_type; + using LayoutPairAE = 
LayoutPairAE_; + using LayoutA = remove_cvref_t(LayoutPairAE{}))>; + static constexpr int ElementAMmaSparsity = ElementAMma::sparsity; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementEMma = typename TiledMma::ValTypeE; + using ElementE = typename ElementEMma::raw_type; + using LayoutE = remove_cvref_t(LayoutPairAE{}))>; + static constexpr int ElementEMmaSparsity = ElementEMma::sparsity; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(is_sparse::value, "ElementAMma is sparse"); + static_assert(!is_sparse::value, "ElementA is not sparse"); + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + // LayoutA is nested in the stride due to the sparsity. + static constexpr bool is_A_mn_major = cute::is_same_v(LayoutA{})), Int>; + + using SparseConfig = cutlass::Sm1xxGemmSparseConfig, + ElementEMma>; + + // The offline permutation for the metadata. 
+ using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + + // Metadata pathways + using GmemCopyAtomE = GmemTiledCopyA; + + using MainloopPipeline = cutlass::PipelineTmaSparseUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static constexpr int UtccpReuseCnt = ((size<2>(TileShape{}) / typename SparseConfig::TensorEAtomK{}) == 0) ? + typename SparseConfig::TensorEAtomK{} / size<2>(TileShape{}) : 1; + static_assert(UtccpReuseCnt == 1 || UtccpReuseCnt == 2, "UTCCP reuse count can only be either one or two"); + // (TileM, TileN, TileK) TileK is adjusted according to the reuse. + using TileShapeE = decltype(replace<2>(TileShape{}, cute::lcm(size<2>(TileShape{}), typename SparseConfig::TensorEAtomK{}))); + using MmaShapeE_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShapeE{}), size<2>(TileShapeE{})))); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + 
static_assert(rank(SmemLayoutAtomE{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomE{})) == 0, "SmemLayoutAtomE must evenly divide the tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t, Step<_1,_2,_3>>{})); + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) that one UTCCP instruction can provide + using SmemLayoutE = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomE{}, + append(MmaShapeE_MK{}, Int{}))); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static_assert(rank(SmemLayoutAtomE{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomE{})) == 0, "SmemLayoutAtomE must evenly divide tile shape."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_sparse_f8f6f4(); + + using TmaInternalElementA = cute::sparse_elem, + cutlass::tfloat32_t, + ElementAMmaRaw>>; + 
using TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::sparse_elem < 8, + uint8_t, + ElementAMmaRaw>>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + // Kernel Input Data Type that consider runtime dtype + using ArrayElementA = cute::conditional_t>, + ElementA>; + using ArrayElementB = cute::conditional_t>, + ElementB>; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_E; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t MetadataTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutE{})) * cute::sizeof_bits_v); + static constexpr uint32_t MainLoadTmaTransactionBytes = ABTmaTransactionBytes; + + template + struct TmemStorage { + AccTensor accumulators; + ETensor tCtE; + }; + + template < + class KTileCount, class KTileMetadataCount, + class GTensorPartitionedA, class GTensorPartitionedB, class GTensorPartitionedE, + class STensorA, class STensorB, class STensorE + > + struct LoadParams { + // for scheduler + KTileCount k_tiles; + 
KTileMetadataCount k_tiles_metadata; + // for input tensor values + GTensorPartitionedA tAgA_mkl; + GTensorPartitionedB tBgB_nkl; + GTensorPartitionedE tEgE_nkl; + STensorA tAsA; + STensorB tBsB; + STensorE tEsE; + // the TMA multicast masks + uint16_t mcast_mask_a; + uint16_t mcast_mask_b; + uint16_t mcast_mask_e; + + CUTLASS_DEVICE + LoadParams ( + KTileCount k_tiles_, KTileMetadataCount k_tiles_metadata_, + GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_, GTensorPartitionedE tEgE_nkl_, + STensorA tAsA_, STensorB tBsB_, STensorE tEsE_, + uint16_t mcast_mask_a_, uint16_t mcast_mask_b_, uint16_t mcast_mask_e_) + : k_tiles(k_tiles_), k_tiles_metadata(k_tiles_metadata_) + , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_), tEgE_nkl(tEgE_nkl_) + , tAsA(tAsA_), tBsB(tBsB_), tEsE(tEsE_) + , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_), mcast_mask_e(mcast_mask_e_) {} + }; + + template < + class TiledMma, + class FragmentA, class FragmentB, + class FragmentE, class ETiledCopy, class SmemFrgE, class TmemFrgE + > + struct MmaParams { + TiledMma tiled_mma; + // A + FragmentA tCrA; + // B + FragmentB tCrB; + // E + FragmentE tCtE; + ETiledCopy tiled_copy_s2t_E; + SmemFrgE thr_tCsE_s2t; + TmemFrgE thr_tCtE_s2t; + + CUTLASS_DEVICE + MmaParams ( + TiledMma tiled_mma_, + FragmentA tCrA_, FragmentB tCrB_, + FragmentE tCtE_, ETiledCopy tiled_copy_s2t_E_, + SmemFrgE thr_tCsE_s2t_, TmemFrgE thr_tCtE_s2t_) + : tiled_mma(tiled_mma_) + , tCrA(tCrA_), tCrB(tCrB_) + , tCtE(tCtE_), tiled_copy_s2t_E(tiled_copy_s2t_E_) + , thr_tCsE_s2t(thr_tCsE_s2t_), thr_tCtE_s2t(thr_tCtE_s2t_) {} + }; + + // Host side kernel arguments + struct Arguments { + // A is A Compressed, not raw tensorA + ArrayElementA const* ptr_A{nullptr}; + LayoutA layout_a{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementE const* ptr_E{nullptr}; + LayoutE layout_e{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + 
struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), LayoutA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_E = decltype(make_tma_atom_A_sm100( // use uint64_t to get the largest loading box. + GmemCopyAtomE{}, + make_tensor(recast_ptr(nullptr), LayoutE{}), + SmemLayoutE{}(_,_,_,cute::Int<0>{}), + TileShapeE{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_E tma_load_e; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_E tma_load_e_fallback; + TMA_B tma_load_b_fallback; + LayoutA layout_a; + LayoutE layout_e; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , layout_a_(params.layout_a) + , layout_e_(params.layout_e) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_e_ = is_fallback_cluster ? 
¶ms.tma_load_e_fallback : ¶ms.tma_load_e; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_e_ = ¶ms.tma_load_e; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + auto ptr_E = recast_ptr(args.ptr_E); + + Tensor tensor_a = make_tensor(ptr_A, args.layout_a); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + Tensor tensor_e = make_tensor(ptr_E, args.layout_e); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_E tma_load_e = make_tma_atom_A_sm100( // use uint64_t to get the largest loading box. 
+ GmemCopyAtomE{}, + tensor_e, + SmemLayoutE{}(_,_,_,cute::Int<0>{}), + TileShapeE{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_E tma_load_e_fallback = make_tma_atom_A_sm100( // use uint64_t to get the largest loading box. + GmemCopyAtomE{}, + tensor_e, + SmemLayoutE{}(_,_,_,cute::Int<0>{}), + TileShapeE{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_e, + tma_load_b, + tma_load_a_fallback, + tma_load_e_fallback, + tma_load_b_fallback, + args.layout_a, + args.layout_e, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits_v; + + bool implementable = true; + // Check Alignment A + if constexpr (is_A_mn_major) { + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M, K/2, L), + cute::make_stride(_1{}, M, M*K/2)); + } + else { // If A is K-major + implementable = implementable 
&& cutlass::detail::check_alignment(cute::make_shape(M, K/2, L), + cute::make_stride(K/2, _1{}, M*K/2)); + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA on tensorA\n"); + } + + // Check Alignment B + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits_v; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA on tensorB\n"); + } + + // Check for AB layout requirement + const auto layout_a_ref = SparseConfig::fill_layoutA(problem_shape_MNKL); + const auto layout_e_ref = SparseConfig::fill_layoutE(problem_shape_MNKL); + implementable = implementable && (layout_a_ref == args.layout_a); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_a mismatch\n"); + } + + implementable = implementable && (layout_e_ref == args.layout_e); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_e mismatch\n"); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_e_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + 
slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtE = make_tensor(take<0,3>(shape(SmemLayoutE{}))); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtE = tCtE; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtE.data() = tmem_base_addr + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + } + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(layout_a_.shape()); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + Tensor mE_mkl = observed_tma_load_e_->get_tma_tensor(layout_e_.shape()); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + Tensor gE_mkl = local_tile(mE_mkl, TileShapeE{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + Tensor tCgE_mkl = cta_mma.partition_A(gE_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (MMA,MMA_M,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout 
cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // Project the cta_layout for tma_a along the n-modes + auto [tEgE_mkl, tEsE] = tma_partition(*observed_tma_load_e_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sE), group_modes<0,3>(tCgE_mkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_e = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + + return LoadParams{ + size<3>(gA_mkl), size<3>(gE_mkl), // for scheduler + tAgA_mkl, tBgB_nkl, tEgE_mkl, tAsA, tBsB, tEsE, // for input tensor values + mcast_mask_a, mcast_mask_b, mcast_mask_e}; // multicast masks + } + + /// Set up the data needed by this collective for mma compute. 
+ template + CUTLASS_DEVICE auto + mma_init( + TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A B E matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (MMA,MMA_M,MMA_K,PIPE) that one UTCCP can provide + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sE)); // PIPE + + Tensor tCtE = tmem_storage.tCtE; + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpEOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + cute::SM100_UTCCP_128dp128bit_2cta, cute::SM100_UTCCP_128dp128bit_1cta>; + auto tiled_copy_s2t_E = make_utccp_copy(UtccpEOp{}, recast(tCtE)); + + auto thr_copy_s2t_E = tiled_copy_s2t_E.get_slice(0); + Tensor thr_tCsE_s2t_ = thr_copy_s2t_E.partition_S(recast(sE)); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + Tensor thr_tCsE_s2t = get_utccp_smem_desc_tensor(thr_tCsE_s2t_); + Tensor thr_tCtE_s2t = thr_copy_s2t_E.partition_D(recast(tCtE)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + + return MmaParams{ + tiled_mma, + tCrA, tCrB, + tCtE, tiled_copy_s2t_E, thr_tCsE_s2t, thr_tCtE_s2t}; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class LoadParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + LoadParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [k_tiles, k_tiles_metadata, + tAgA_mkl, tBgB_nkl, tEgE_mkl, tAsA, tBsB, tEsE, + mcast_mask_a, mcast_mask_b, mcast_mask_e] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tEgE = tEgE_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + uint32_t iter = 0; + + // K_tile_iter for E + auto k_tile_start = cute::crd2idx(k_tile_iter.coord, k_tiles); + auto k_utccp_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start / UtccpReuseCnt, k_tiles_metadata), k_tiles_metadata); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + bool load_e = iter % UtccpReuseCnt == 0; + + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, load_e, barrier_token); + // Note: We don't synchronize the sf_pipeline for "Buffer_Empty". We use mainloop pipeline + // to do the synchronization at once. 
+ + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + + if (load_e) { + if (cute::elect_one_sync()) { + copy(observed_tma_load_e_->with(*tma_barrier, mcast_mask_e), tEgE(_,*k_utccp_tile_iter), tEsE(_,write_stage)); + } + ++k_utccp_tile_iter; + } + + ++k_tile_iter; + --k_tile_count; + iter++; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class MmaParams, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + MmaParams const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + 
static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, + tCrA, tCrB, + tCtE, tiled_copy_s2t_E, thr_tCsE_s2t, thr_tCtE_s2t ] = mma_inputs; + + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + uint32_t math_mma_e_stage_idx = 0; + uint32_t iter = 0; + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + if constexpr (not IsOverlappingAccum) { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + if constexpr (UtccpReuseCnt == 1) { + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_E, thr_tCsE_s2t(_,_,_,_,read_stage), thr_tCtE_s2t); + } + } + else { + if (not (iter & 1)) { + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_E, thr_tCsE_s2t(_,_,_,_,read_stage), thr_tCtE_s2t); + } + } + } + + if constexpr (IsOverlappingAccum) { + if (iter == 0) { + 
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + } + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tCtE(_,_,math_mma_e_stage_idx * UtccpReuseCnt + k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + if constexpr (UtccpReuseCnt != 1) { + // Each E Smem Stage contain two CtaK's Metadata when UtccpReuse + math_mma_e_stage_idx ^= 1; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + ++iter; + } + + return mainloop_pipe_consumer_state; + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_E const* observed_tma_load_e_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + LayoutA layout_a_; + LayoutE layout_e_; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e90d727826e51ddecb4c6e1c33eaf230f9220d11 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp @@ 
-0,0 +1,1685 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm103_blockscaled_layout.hpp" +#include "cutlass/detail/collective/sm103_kernel_type.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int LoadABPipelineStageCount, + int LoadSFPipelineStageCount, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, int) + cutlass::sm103::detail::KernelPrefetchType PrefetchType, + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + 
MainloopSm103ArrayTmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm103ArrayTmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>; + + using TileShape = TileShape_; + // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. + using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // Assert that TiledMma and TileShape should be weakly compatible + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TiledMma and TileShape should be weakly compatible"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm103BlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int 
IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static int constexpr CTA_N_SF = cutlass::round_up(size<1>(CtaShape_MNK{}), Blk_MN{}); + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + static int constexpr SF_BUFFERS_PER_TILE_K = SFVecSize == 16 ? 4 : 2; + using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/SF_BUFFERS_PER_TILE_K>{})); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + 
using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using InternalLayoutSFA = cute::remove_pointer_t; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + using InternalLayoutSFB = cute::remove_pointer_t; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopABPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadABPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopABPipelineState = typename MainloopABPipeline::PipelineState; + + using MainloopSFPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadSFPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; + + static_assert(cute::is_void_v, + "SM103 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(cute::is_void_v, + "SM103 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,NUM_PIPES) + using SmemLayoutA_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,3) + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,NUM_PIPES) + using SmemLayoutB_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,3) + + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. 
+ // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + using TmaInternalElementA = uint8_t; + using TmaInternalElementB = uint8_t; + + using SmemAllocTypeA = uint8_t; + using SmemAllocTypeB = uint8_t; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + using SmemPrefetchType = uint8_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + struct TensorMapStorage : 
cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_SFA; + cute::TmaDescriptor smem_tensormap_SFB; + } tensormaps; + + using PipelineABStorage = typename MainloopABPipeline::SharedStorage; + using PipelineSFStorage = typename MainloopSFPipeline::SharedStorage; + struct PipelineStorage { + PipelineABStorage pipeline_ab; + PipelineSFStorage pipeline_sf; + }; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + // Host side kernel arguments + struct Arguments { + ArrayElementA const** ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const** ptr_B{nullptr}; + StrideB dB{}; + ElementSF const** ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const** ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), 
make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom( + GmemTiledCopyA{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{})), + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(ClusterShape{})) + ); + + using TMA_B = decltype(make_tma_atom( + GmemTiledCopyB{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{})), + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(ClusterShape{})/size(typename TiledMma::AtomThrID{})) + ); + + using TMA_SFA = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), InternalLayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(ClusterShape{})) + ); + + using TMA_SFB = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), InternalLayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(ClusterShape{})/size(typename TiledMMA_SF::AtomThrID{})) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const** ptr_A; + StrideA dA; + ArrayElementB const** ptr_B; + StrideB dB; + ElementSF const** 
ptr_SFA; + ElementSF const** ptr_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? ¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shapes, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_M = int32_t(size<0>(TileShape{})); + auto init_N = int32_t(size<1>(TileShape{})); + auto init_K = int32_t(size<2>(TileShape{})); + auto init_L = 1; + + // Tensor pointers will be fixed before the first access + ElementA const* ptr_A_first_batch = nullptr; + ElementB const* ptr_B_first_batch = nullptr; + + InternalStrideA stride_a; + InternalStrideB stride_b; + InternalLayoutSFA layout_SFA; + InternalLayoutSFB layout_SFB; + + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. 
+ stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + layout_SFA = args.layout_SFA; + layout_SFB = args.layout_SFB; + } + + // Batches/Groups are managed by using appropriate pointers to input matrices. + Tensor tensor_a = recast(make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a))); + Tensor tensor_b = recast(make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b))); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + // Tensor pointers will be fixed before the first access + ElementSF const* ptr_SFA_first_batch = nullptr; + ElementSF const* ptr_SFB_first_batch = nullptr; + + Tensor tensor_sfa = make_tensor(ptr_SFA_first_batch, layout_SFA); + Tensor tensor_sfb = make_tensor(ptr_SFB_first_batch, layout_SFB); + + typename Params::TMA_A tma_load_a = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape) + ); + + typename 
Params::TMA_B tma_load_b = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape)/size(typename TiledMma::AtomThrID{}) + ); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape_fallback) + ); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape_fallback)/size(typename TiledMma::AtomThrID{}) + ); + + typename Params::TMA_SFA tma_load_sfa = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape) + ); + + typename Params::TMA_SFB tma_load_sfb = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape)/size(typename TiledMMA_SF::AtomThrID{}) + ); + + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape_fallback) + ); + + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape_fallback)/size(typename TiledMMA_SF::AtomThrID{}) + ); + + #if 0 + print("tma_load_a:\n"); + print(tma_load_a); + print("tma_load_a.tma_desc:\n"); print(tma_load_a.tma_desc_); print("\n"); + + print("tma_load_b:\n"); + print(tma_load_b); + print("tma_load_b.tma_desc:\n"); print(tma_load_b.tma_desc_); print("\n"); + + 
print("layout_SFA: "); print(args.layout_SFA); print("\n"); + print("tma_load_sfa:\n"); + print(tma_load_sfa); + print("tma_load_sfa.tma_desc:\n"); print(tma_load_sfa.tma_desc_); print("\n"); + + print("layout_SFB: "); print(args.layout_SFB); print("\n"); + print("tma_load_sfb:\n"); + print(tma_load_sfb); + print("tma_load_sfb.tma_desc:\n"); print(tma_load_sfb.tma_desc_); print("\n"); + + print("layout_sfa: "); print(args.layout_SFA); print("\n"); + print("tma_load_sfa_fallback:\n"); + print(tma_load_sfa_fallback); + print("tma_load_sfa_fallback.tma_desc:\n"); print(tma_load_sfa_fallback.tma_desc_); print("\n"); + + print("layout_sfb: "); print(args.layout_SFB); print("\n"); + print("tma_load_sfb_fallback:\n"); + print(tma_load_sfb_fallback); + print("tma_load_sfb_fallback.tma_desc:\n"); print(tma_load_sfb_fallback.tma_desc_); print("\n"); + #endif + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + args.layout_SFA, + args.layout_SFB, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + reinterpret_cast(args.ptr_SFA), + reinterpret_cast(args.ptr_SFB) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 4; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + 
+ template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if constexpr (IsRuntimeDataType && detail::is_sm10x_mxf4nvf4_input() && detail::is_sm10x_mxf4nvf4_input()) { + bool is_compatible = (SFVecSize == 16 || + (SFVecSize == 32 && is_same_v + && args.runtime_data_type_a == cute::UMMA::MXF4Format::E2M1 + && args.runtime_data_type_b == cute::UMMA::MXF4Format::E2M1)); + if (!is_compatible) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: 2x mode (VectorSize=32) only supports float_e2m1_t for a/b types and ue8m0_t for sf type.\n"); + } + implementable &= is_compatible; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE auto + 
slice_accumulator(cute::Tensor const& accumulators, int stage) { + return accumulators(_,_,_,stage); + } + + template + CUTLASS_DEVICE auto + get_mkl_shape_tensor ( + ProblemShape_MNKL const& problem_shape_MNKL) const { + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + int K_recast = (K*cute::sizeof_bits_v/8); + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K_recast,mock_L)); + Tensor gA_mkl = local_tile(mA_mkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step<_1, X,_1>{}); + return gA_mkl; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_ab_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + int K_recast = (K*cute::sizeof_bits_v/8); + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K_recast,mock_L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K_recast,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), 
Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl_tmp = cta_mma.partition_A(gA_mkl); // ((CTA_MMA_M,96),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor cta_tCgA = make_tensor(tCgA_mkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgA_mkl_tmp), cute::layout<1>(tCgA_mkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgA_mkl_tmp), cute::layout<2>(tCgA_mkl_tmp))), + cute::layout<3>(tCgA_mkl_tmp), cute::layout<4>(tCgA_mkl_tmp), cute::layout<5>(tCgA_mkl_tmp))); // (CTA_M,CTA_K,m,k,l) + + Tensor tCgA_mkl = make_tensor(cta_tCgA.data(), tiled_divide(cta_tCgA.layout(), + make_tile(size<1,0>(typename TiledMma::ALayout{}) /*MMA_M for SM100*/, + _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M,Rest_MMA_K, m, k, l) + + Tensor tCgB_nkl_tmp = cta_mma.partition_B(gB_nkl); // ((MMA_ATOM_M,96),Rest_MMA_M,Rest_MMA_K, n, k, l) + Tensor cta_tCgB = make_tensor(tCgB_nkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgB_nkl_tmp), cute::layout<1>(tCgB_nkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgB_nkl_tmp), cute::layout<2>(tCgB_nkl_tmp))), + cute::layout<3>(tCgB_nkl_tmp), cute::layout<4>(tCgB_nkl_tmp), cute::layout<5>(tCgB_nkl_tmp))); // (CTA_M,CTA_K,m,k,l) + Tensor tCgB_nkl = make_tensor(cta_tCgB.data(), tiled_divide(cta_tCgB.layout(), + make_tile(size<1,0>(typename TiledMma::BLayout{}) /*MMA_M for SM100*/, + _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M, Rest_MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_N,32),Rest_MMA_N,8,NUM_PIPE) + + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = 
tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,1>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,1>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init_ab(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_sf_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + + InternalLayoutSFA layout_SFA{}; + InternalLayoutSFB layout_SFB{}; + if constexpr (IsGroupedGemmKernel) { + layout_SFA = params.layout_SFA[init_group]; + layout_SFB = params.layout_SFB[init_group]; + } + else { + layout_SFA = params.layout_SFA; + layout_SFB = params.layout_SFB; + } + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + } + }(); + + // Partition for this CTA + Tensor gSFA_mkl = local_tile(mSFA_mkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + Tensor tCgSFA_mkl = make_tensor(gSFA_mkl.data(), 
tiled_divide(gSFA_mkl.layout(), make_tile(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_M,MMA_K),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor tCgSFB_nkl = make_tensor(gSFB_nkl.data(), tiled_divide(gSFB_nkl.layout(), make_tile(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_N,MMA_K),Rest_MMA_N,Rest_MMA_K, n, k, l) + + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(tCsSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + + auto input_tensormaps = tensormaps_init_sf(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale 
factor tensor values + mcast_mask_sfa, mcast_mask_sfb, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + /// Set up the data needed by this collective for mma compute. + CUTLASS_DEVICE auto + mma_init( + Params const& params, + TensorStorage& shared_tensors, + uint32_t const tmem_offset) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = make_tensor(sA);; + Tensor tCrB = make_tensor(sB);; + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = make_tensor(take<0,3>(shape(SmemLayoutAtomSFA{}))); + // TMEM allocations for SFA and SFB will always start at DP 0. + tCtSFA.data() = tmem_offset; + Tensor tCtSFB = make_tensor(take<0,3>(shape(SmemLayoutAtomSFB{}))); + + tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA); + + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. 
+ auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tCtSFA_compact_copy = make_tensor(tCtSFA_compact.data(), append<3>(tCtSFA_compact(_,_0{},_0{}).layout())); + auto tCtSFB_compact_copy = make_tensor(tCtSFB_compact.data(), append<3>(tCtSFB_compact(_,_0{},_0{}).layout())); + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact_copy); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact_copy); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 
0b111; + } + + // using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/2>{})); // 128x128x384 + // MMA shapes are ((_128,_96),_1,_8) which makes the MMA_SFA_Shape ((128, (16,3)), 1, 8/3) + // The number is not divisible by 4 in K dimension which is needed for TMEM allocation. + // To be able to iterate thru the SFs for MMA, we model this as (MMA), MMA_M, MMA_K: ((128, (16,1)), 1, 24) + // with this layout we can iterate thru the SFs by incrementing MMA_K mode by 3/6 for this example (Vs=16 vs Vs=32). + constexpr int MMA_M = size<0>(CtaShape_MNK{}); + constexpr int MMA_N_SF = CTA_N_SF; + constexpr int MMA_K_SF = shape<2>(CtaShape_MNK{}) / 2; + auto mnBasicBlockShape = make_shape(_32{}, _4{}); + auto kBasicBlockShape_single = make_shape(Int{}, Int<1>{}); + auto mma_iter_SFA_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFA_iter_shape = make_shape(mma_iter_SFA_shape, _1{}, Int{}); + auto mma_iter_SFB_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFB_iter_shape = make_shape(mma_iter_SFB_shape, _1{}, Int{}); + + // Used for MMAs + using MmaIterShapeSFA = decltype(sSFA_iter_shape); // ((32,4),(SFVecSize,1), MMA_M/128, SF_MMA_K/SfVecSize + using MmaIterShapeSFB = decltype(sSFB_iter_shape); // ((32,4),(SFVecSize,1), MMA_N/128, SF_MMA_K/SfVecSize + + Tensor tCtSFA_mma = make_tensor(MmaIterShapeSFA{}); + tCtSFA_mma.data() = tCtSFA.data(); + Tensor tCtSFB_mma = make_tensor(MmaIterShapeSFB{}); + tCtSFB_mma.data() = tCtSFB.data(); + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, tCtSFA_mma, tCtSFB_mma, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t); + } + +// Helper function to handle both prefetch types + template + CUTLASS_DEVICE void issue_prefetch( + int& prefetch_k_tile_count, + int& prefetch_buf_idx, + KTileIterator& prefetch_k_tile, 
+ TmaPrefetchFn&& tma_prefetch_fn) + { + if (prefetch_k_tile_count > 0) { + if constexpr (PrefetchType == cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch) { + tma_prefetch_fn(); + } + + prefetch_buf_idx = (prefetch_buf_idx + 1) % BuffersPerKtile; + if(prefetch_buf_idx == 0) { + ++prefetch_k_tile; + --prefetch_k_tile_count; + } + } + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_ab( + Params const& params, + MainloopABPipeline pipeline, + MainloopABPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, int prefetch_k_tile_count = 0) { + + auto tAgA_mkl = get<2>(load_inputs); + auto tBgB_nkl = get<3>(load_inputs); + auto tAsA = get<4>(load_inputs); + auto tBsB = get<5>(load_inputs); + auto mcast_mask_a = get<6>(load_inputs); + auto mcast_mask_b = get<7>(load_inputs); + auto input_tensormaps = get<8>(load_inputs); + + if (did_batch_change) { + tensormaps_fence_acquire(get<0>(input_tensormaps)); + tensormaps_fence_acquire(get<1>(input_tensormaps)); + } + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, _, _, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + constexpr int BuffersPerKtile = 3; + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadABPipelineStageCount / BuffersPerKtile; + + if constexpr (PrefetchType != 
cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadABPipelineStageCount % BuffersPerKtile; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + using BarrierType = typename MainloopABPipeline::ProducerBarrierType; + // In total, we will load 3 buffers per k_tile_iter. Unrolled. + CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < BuffersPerKtile; buffer++) { + pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + auto tma_copy_traits_a = observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a); + auto tma_copy_traits_b = observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b); + + if (cute::elect_one_sync()) { + copy(tma_copy_traits_a, group_modes<0,2>(tAgA(_,_,buffer,*k_tile_iter)), tAsA(_,write_stage)); + copy(tma_copy_traits_b, group_modes<0,2>(tBgB(_,_,buffer,*k_tile_iter)), tBsB(_,write_stage)); + } + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + prefetch(tma_copy_traits_a, group_modes<0,2>(tAgA(_,_,prefetch_buf_idx,*prefetch_k_tile))); + prefetch(tma_copy_traits_b, group_modes<0,2>(tBgB(_,_,prefetch_buf_idx,*prefetch_k_tile))); + } + ); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class TensorMapSFA, class TensorMapSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + 
load_sf( + Params const& params, + MainloopSFPipeline pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, int prefetch_k_tile_count = 0) { + + auto tAgSFA_mkl = get<0>(load_inputs); + auto tBgSFB_nkl = get<1>(load_inputs); + auto tAsSFA = get<2>(load_inputs); + auto tBsSFB = get<3>(load_inputs); + auto mcast_mask_sfa = get<4>(load_inputs); + auto mcast_mask_sfb = get<5>(load_inputs); + auto input_tensormaps_sf = get<6>(load_inputs); + // slice out the work coord from partitioned tensors + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(get<0>(input_tensormaps_sf)); + tensormaps_fence_acquire(get<1>(input_tensormaps_sf)); + } + + auto barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + + using BarrierType = typename MainloopSFPipeline::ProducerBarrierType; + auto tAsSFA_compact = make_tensor(tAsSFA.data(), filter_zeros(tAsSFA.layout())); + auto tBsSFB_compact = make_tensor(tBsSFB.data(), filter_zeros(tBsSFB.layout())); + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadSFPipelineStageCount / SF_BUFFERS_PER_TILE_K; + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadSFPipelineStageCount % SF_BUFFERS_PER_TILE_K; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + // In total, we will load 2 or 4 buffers per k_tile_iter. Unrolled. 
+ CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < SF_BUFFERS_PER_TILE_K; buffer++) { + pipeline.producer_acquire(mainloop_sf_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_sf_pipe_producer_state); + + int write_stage = mainloop_sf_pipe_producer_state.index(); + ++mainloop_sf_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + auto tAgSFA_compact = make_tensor(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + auto tBgSFB_compact = make_tensor(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + + auto tma_copy_traits_sfa = observed_tma_load_sfa_->with(get<0>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfa); + auto tma_copy_traits_sfb = observed_tma_load_sfb_->with(get<1>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfb); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_sfa_->with(get<0>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfa), tAgSFA_compact, tAsSFA_compact(_,write_stage)); + copy(observed_tma_load_sfb_->with(get<1>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfb), tBgSFB_compact, tBsSFB_compact(_,write_stage)); + } + + auto tAgSFA_compact_prefetch = make_tensor(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + auto tBgSFB_compact_prefetch = make_tensor(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + 
prefetch(tma_copy_traits_sfa, tAgSFA_compact_prefetch); + prefetch(tma_copy_traits_sfb, tBgSFB_compact_prefetch); + } + ); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_sf_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + template < + class MainloopPipeline, class MainloopPipelineState + > + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class MmaFragmentSFA, class MmaFragmentSFB, + class CtaTileCoord, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto pipeline_ab = get<0>(pipelines); + auto pipeline_sf = get<1>(pipelines); + auto accumulator_pipeline = get<2>(pipelines); + auto mainloop_pipe_ab_consumer_state = get<0>(pipeline_states); + auto mainloop_pipe_sf_consumer_state = get<1>(pipeline_states); + auto accumulator_pipe_producer_state = 
get<2>(pipeline_states); + auto tiled_mma = get<0>(mma_inputs); + auto tCrA = get<1>(mma_inputs); + auto tCrB = get<2>(mma_inputs); + auto tCtSFA = get<3>(mma_inputs); + auto tCtSFB = get<4>(mma_inputs); + auto tCtSFA_mma = get<5>(mma_inputs); + auto tCtSFB_mma = get<6>(mma_inputs); + auto tiled_copy_s2t_SFA = get<7>(mma_inputs); + auto tCsSFA_s2t = get<8>(mma_inputs); + auto tCtSFA_s2t = get<9>(mma_inputs); + auto tiled_copy_s2t_SFB = get<10>(mma_inputs); + auto tCsSFB_s2t = get<11>(mma_inputs); + auto tCtSFB_s2t = get<12>(mma_inputs); + + tCtSFB_mma = [tCtSFB_mma = tCtSFB_mma, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB_mma; + if (get<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else { + return tCtSFB_mma; + } + }(); + + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + constexpr int sf_stride = TiledMma::SFVecSize == 16 ? 
6 : 3; + auto barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + auto barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state); + constexpr int MmasPerSfBuffer = 8 / SF_BUFFERS_PER_TILE_K; + + auto sf_load_fn = [&](const int kphase, const int k_tile_count) { + if (kphase % MmasPerSfBuffer == 0) { + pipeline_sf.consumer_wait(mainloop_pipe_sf_consumer_state, barrier_token_sf); + int read_stage_sf_buffer0 = mainloop_pipe_sf_consumer_state.index(); + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, tCsSFA_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, tCsSFB_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFB_s2t); + } + auto buffer0_mainloop_pipe_sf_consumer_state = mainloop_pipe_sf_consumer_state; + ++mainloop_pipe_sf_consumer_state; + barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state, (kphase == 8 - MmasPerSfBuffer) && k_tile_count <= 1); // only skip wait for the last one. + pipeline_sf.consumer_release(buffer0_mainloop_pipe_sf_consumer_state); + } + }; + + bool is_first_iteration = true; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // MMA 0 + sf_load_fn(0, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer0 = mainloop_pipe_ab_consumer_state.index(); + auto buffer0_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + // delay the acc acquire to unblock tmem copy. 
+ if constexpr (IsOverlappingAccum) { + if(is_first_iteration) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + is_first_iteration = false; + } + }; + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,0,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,0,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + + // MMA 1 + sf_load_fn(1, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,3,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 48 bytes. Note the 3. + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,3,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 48 bytes. Note the 3. 
+ tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + + // MMA 2 + sf_load_fn(2, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer1 = mainloop_pipe_ab_consumer_state.index(); + auto buffer1_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,6,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,6,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer0_mainloop_pipe_ab_consumer_state); + + + // MMA 3 + sf_load_fn(3, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,1,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,1,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. 
+ tCtSFB_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 4 + sf_load_fn(4, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,4,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,4,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 5 + sf_load_fn(5, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer2 = mainloop_pipe_ab_consumer_state.index(); + auto buffer2_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state, k_tile_count <= 1); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,7,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,7,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. 
+ tCtSFB_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer1_mainloop_pipe_ab_consumer_state); + + // MMA 6 + sf_load_fn(6, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,2,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,2,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. + tCtSFB_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + // MMA 7 + sf_load_fn(7, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,5,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,5,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. 
+ tCtSFB_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer2_mainloop_pipe_ab_consumer_state); + --k_tile_count; + } + return cute::make_tuple(mainloop_pipe_ab_consumer_state, mainloop_pipe_sf_consumer_state); + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + CUTLASS_DEVICE auto + tensormaps_init_ab( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address_ab( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + 
mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties_ab( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + ElementA const* ptr_A = nullptr; + Tensor tensor_a = recast(make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group])); + + ElementB const* ptr_B = nullptr; + Tensor tensor_b = recast(make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update_ab( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple 
const& input_ab_tensormaps, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address_ab(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties_ab(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release_ab(shared_tensormaps, input_ab_tensormaps); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release_ab ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_ab_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_ab_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_ab_tensormaps), shared_tensormaps.smem_tensormap_B); + + } + + // SF tensormap ops + CUTLASS_DEVICE auto + tensormaps_init_sf( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_sfa = &gmem_tensormap[sm_idx + 2 * sm_count]; + cute::TmaDescriptor* tma_desc_sfb = &gmem_tensormap[sm_idx + 3 * sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pSFA_tensormap = make_tensor(observed_tma_load_sfa_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFA), Int<1>{}, Int<1>{}); + Tensor pSFB_tensormap = make_tensor(observed_tma_load_sfb_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFB), Int<1>{}, Int<1>{}); + + copy(recast(pSFA_tensormap), recast(sSFA_tensormap)); + copy(recast(pSFB_tensormap), recast(sSFB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_sfa, tma_desc_sfb); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address_sf( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + mainloop_params.ptr_SFA[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + mainloop_params.ptr_SFB[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped 
GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties_sf( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_SFA = {1,1,1,1,1}; + cute::array prob_stride_SFA = {0,0,0,0,0}; + cute::array prob_shape_SFB = {1,1,1,1,1}; + cute::array prob_stride_SFB = {0,0,0,0,0}; + + ElementSF const* ptr_SF = nullptr; + Tensor tensor_sfa = make_tensor(ptr_SF, mainloop_params.layout_SFA[next_group]); + + Tensor tensor_sfb = make_tensor(ptr_SF, mainloop_params.layout_SFB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfa_, tensor_sfa, + prob_shape_SFA, prob_stride_SFA); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfb_, tensor_sfb, + prob_shape_SFB, prob_stride_SFB); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_SFA) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_SFB) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + prob_shape_SFA, + prob_stride_SFA); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + prob_shape_SFB, + prob_stride_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update_sf( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps_sf, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next 
batch + tensormaps_replace_global_address_sf(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties_sf(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release_sf(shared_tensormaps, input_tensormaps_sf); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release_sf ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps_sf) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps_sf), shared_tensormaps.smem_tensormap_SFA); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps_sf), shared_tensormaps.smem_tensormap_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::TmaDescriptor const* input_tma_desc) { + cute::tma_descriptor_fence_acquire(input_tma_desc); + } + +protected: + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fefd73271556ff263f9cd836e612a454ee7ee01c --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp @@ -0,0 +1,1276 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm103_blockscaled_layout.hpp" +#include "cutlass/detail/collective/sm103_kernel_type.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int LoadABPipelineStageCount, + int LoadSFPipelineStageCount, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static 
cluster shape or dynamic (int, int, int) + cutlass::sm103::detail::KernelPrefetchType PrefetchType, + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm103TmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm103TmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>; + + using TileShape = TileShape_; + // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. 
+ using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // Assert that TiledMma and TileShape should be weakly compatible + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TiledMma and TileShape should be weakly compatible"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm103BlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static int constexpr CTA_N_SF = cutlass::round_up(size<1>(CtaShape_MNK{}), Blk_MN{}); + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + static int constexpr SF_BUFFERS_PER_TILE_K = SFVecSize == 16 ? 
4 : 2; + using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/SF_BUFFERS_PER_TILE_K>{})); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = 
remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopABPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadABPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopABPipelineState = typename MainloopABPipeline::PipelineState; + + using MainloopSFPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadSFPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; + + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,NUM_PIPES) + using SmemLayoutA_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,3) + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,NUM_PIPES) + using SmemLayoutB_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,3) + + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. 
+ // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + + using TmaInternalElementA = uint8_t; + using TmaInternalElementB = uint8_t; + + using SmemAllocTypeA = uint8_t; + using SmemAllocTypeB = uint8_t; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + using SmemPrefetchType = uint8_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineABStorage = typename MainloopABPipeline::SharedStorage; + using PipelineSFStorage = 
typename MainloopSFPipeline::SharedStorage; + struct PipelineStorage { + PipelineABStorage pipeline_ab; + PipelineSFStorage pipeline_sf; + }; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom( + GmemTiledCopyA{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), 
StrideA{})), + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(ClusterShape{})) + ); + + using TMA_B = decltype(make_tma_atom( + GmemTiledCopyB{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{})), + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(ClusterShape{})/size(typename TiledMma::AtomThrID{})) + ); + + using TMA_SFA = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(ClusterShape{})) + ); + + using TMA_SFB = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(ClusterShape{})/size(typename TiledMMA_SF::AtomThrID{})) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? 
¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + + } + } + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = recast(make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA))); + Tensor tensor_b = recast(make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB))); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + typename Params::TMA_A tma_load_a = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape) + ); + typename Params::TMA_B tma_load_b = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape)/size(typename 
TiledMma::AtomThrID{}) + ); + typename Params::TMA_A tma_load_a_fallback = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape_fallback) + ); + typename Params::TMA_B tma_load_b_fallback = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape_fallback)/size(typename TiledMma::AtomThrID{}) + ); + typename Params::TMA_SFA tma_load_sfa = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape) + ); + typename Params::TMA_SFB tma_load_sfb = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape)/size(typename TiledMMA_SF::AtomThrID{}) + ); + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape_fallback) + ); + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape_fallback)/size(typename TiledMMA_SF::AtomThrID{}) + ); + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + args.layout_SFA, + args.layout_SFB, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 
1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if constexpr (IsRuntimeDataType && detail::is_sm10x_mxf4nvf4_input() && detail::is_sm10x_mxf4nvf4_input()) { + bool is_compatible = (SFVecSize == 16 || + (SFVecSize == 32 && is_same_v + && args.runtime_data_type_a == cute::UMMA::MXF4Format::E2M1 + && args.runtime_data_type_b == cute::UMMA::MXF4Format::E2M1)); + if (!is_compatible) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: 2x mode (VectorSize=32) only supports float_e2m1_t for a/b types and ue8m0_t for sf type.\n"); + } + implementable &= is_compatible; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfa_fallback.get_tma_descriptor()); + 
cute::prefetch_tma_descriptor(params.tma_load_sfb_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfa.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfb.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfa.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfb.get_tma_descriptor()); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Set up the data needed by this collective for load. 
+ /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_ab_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + int K_recast = (K*cute::sizeof_bits_v/8); + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K_recast,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K_recast,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl_tmp = cta_mma.partition_A(gA_mkl); // ((CTA_MMA_M,96),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor cta_tCgA = make_tensor(tCgA_mkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgA_mkl_tmp), cute::layout<1>(tCgA_mkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgA_mkl_tmp), cute::layout<2>(tCgA_mkl_tmp))), + cute::layout<3>(tCgA_mkl_tmp), cute::layout<4>(tCgA_mkl_tmp), cute::layout<5>(tCgA_mkl_tmp))); // (CTA_M,CTA_K,m,k,l) + + Tensor tCgA_mkl = make_tensor(cta_tCgA.data(), tiled_divide(cta_tCgA.layout(), + make_tile(size<1,0>(typename TiledMma::ALayout{}) /*MMA_M for SM100*/, 
+ _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M,Rest_MMA_K, m, k, l) + + Tensor tCgB_nkl_tmp = cta_mma.partition_B(gB_nkl); // ((MMA_ATOM_M,96),Rest_MMA_M,Rest_MMA_K, n, k, l) + Tensor cta_tCgB = make_tensor(tCgB_nkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgB_nkl_tmp), cute::layout<1>(tCgB_nkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgB_nkl_tmp), cute::layout<2>(tCgB_nkl_tmp))), + cute::layout<3>(tCgB_nkl_tmp), cute::layout<4>(tCgB_nkl_tmp), cute::layout<5>(tCgB_nkl_tmp))); // (CTA_M,CTA_K,m,k,l) + Tensor tCgB_nkl = make_tensor(cta_tCgB.data(), tiled_divide(cta_tCgB.layout(), + make_tile(size<1,0>(typename TiledMma::BLayout{}) /*MMA_M for SM100*/, + _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M, Rest_MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_N,32),Rest_MMA_N,8,NUM_PIPE) + + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,1>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), 
make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,1>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b // multicast masks + ); + } + + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_sf_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(params.layout_SFA)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(params.layout_SFB)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(params.layout_SFB)); + } + }(); + + // Partition for this CTA + Tensor gSFA_mkl = local_tile(mSFA_mkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step<_1, X,_1>{}); 
// (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + Tensor tCgSFA_mkl = make_tensor(gSFA_mkl.data(), tiled_divide(gSFA_mkl.layout(), make_tile(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_M,MMA_K),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor tCgSFB_nkl = make_tensor(gSFB_nkl.data(), tiled_divide(gSFB_nkl.layout(), make_tile(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_N,MMA_K),Rest_MMA_N,Rest_MMA_K, n, k, l) + + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(tCsSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, 
cta_coord_sfb_vmnk); + + return cute::make_tuple( + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_sfa, mcast_mask_sfb // multicast masks + ); + } + + /// Set up the data needed by this collective for mma compute. + CUTLASS_DEVICE auto + mma_init( + Params const& params, + TensorStorage& shared_tensors, + uint32_t const tmem_offset) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = make_tensor(sA);; + Tensor tCrB = make_tensor(sB);; + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = make_tensor(take<0,3>(shape(SmemLayoutAtomSFA{}))); + // TMEM allocations for SFA and SFB will always start at DP 0. + tCtSFA.data() = tmem_offset; + Tensor tCtSFB = make_tensor(take<0,3>(shape(SmemLayoutAtomSFB{}))); + + tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA); + + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. 
+ auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tCtSFA_compact_copy = make_tensor(tCtSFA_compact.data(), append<3>(tCtSFA_compact(_,_0{},_0{}).layout())); + auto tCtSFB_compact_copy = make_tensor(tCtSFB_compact.data(), append<3>(tCtSFB_compact(_,_0{},_0{}).layout())); + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact_copy); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact_copy); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 
0b111; + } + + // using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/2>{})); // 128x128x384 + // MMA shapes are ((_128,_96),_1,_8) which makes the MMA_SFA_Shape ((128, (16,3)), 1, 8/3) + // The number is not divisible by 4 in K dimension which is needed for TMEM allocation. + // To be able to iterate thru the SFs for MMA, we model this as (MMA), MMA_M, MMA_K: ((128, (16,1)), 1, 24) + // with this layout we can iterate thru the SFs by incrementing MMA_K mode by 3/6 for this example (Vs=16 vs Vs=32). + constexpr int MMA_M = size<0>(CtaShape_MNK{}); + constexpr int MMA_N_SF = CTA_N_SF; + constexpr int MMA_K_SF = shape<2>(CtaShape_MNK{}) / 2; + auto mnBasicBlockShape = make_shape(_32{}, _4{}); + auto kBasicBlockShape_single = make_shape(Int{}, Int<1>{}); + auto mma_iter_SFA_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFA_iter_shape = make_shape(mma_iter_SFA_shape, _1{}, Int{}); + auto mma_iter_SFB_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFB_iter_shape = make_shape(mma_iter_SFB_shape, _1{}, Int{}); + + // Used for MMAs + using MmaIterShapeSFA = decltype(sSFA_iter_shape); // ((32,4),(SFVecSize,1), MMA_M/128, SF_MMA_K/SfVecSize + using MmaIterShapeSFB = decltype(sSFB_iter_shape); // ((32,4),(SFVecSize,1), MMA_N/128, SF_MMA_K/SfVecSize + + Tensor tCtSFA_mma = make_tensor(MmaIterShapeSFA{}); + tCtSFA_mma.data() = tCtSFA.data(); + Tensor tCtSFB_mma = make_tensor(MmaIterShapeSFB{}); + tCtSFB_mma.data() = tCtSFB.data(); + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, tCtSFA_mma, tCtSFB_mma, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t); + } + +// Helper function to handle both prefetch types + template + CUTLASS_DEVICE void issue_prefetch( + int& prefetch_k_tile_count, + int& prefetch_buf_idx, + KTileIterator& prefetch_k_tile, 
+ TmaPrefetchFn&& tma_prefetch_fn + ) + { + if (prefetch_k_tile_count > 0) { + if constexpr (PrefetchType == cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch) { + tma_prefetch_fn(); + } + prefetch_buf_idx = (prefetch_buf_idx + 1) % BuffersPerKtile; + if(prefetch_buf_idx == 0) { + ++prefetch_k_tile; + --prefetch_k_tile_count; + } + } + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_ab( + Params const& params, + MainloopABPipeline pipeline, + MainloopABPipelineState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + int prefetch_k_tile_count = 0) { + + auto tAgA_mkl = get<2>(load_inputs); + auto tBgB_nkl = get<3>(load_inputs); + auto tAsA = get<4>(load_inputs); + auto tBsB = get<5>(load_inputs); + auto mcast_mask_a = get<6>(load_inputs); + auto mcast_mask_b = get<7>(load_inputs); + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, _, _, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + constexpr int BuffersPerKtile = 3; + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadABPipelineStageCount / BuffersPerKtile; + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadABPipelineStageCount % BuffersPerKtile; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + using BarrierType = typename MainloopABPipeline::ProducerBarrierType; + // In total, we will load 3 buffers per 
k_tile_iter. Unrolled. + CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < BuffersPerKtile; buffer++) { + pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), group_modes<0,2>(tAgA(_,_,buffer,*k_tile_iter)), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), group_modes<0,2>(tBgB(_,_,buffer,*k_tile_iter)), tBsB(_,write_stage)); + } + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + prefetch(*observed_tma_load_a_, group_modes<0,2>(tAgA(_,_,prefetch_buf_idx,*prefetch_k_tile))); + prefetch(*observed_tma_load_b_, group_modes<0,2>(tBgB(_,_,prefetch_buf_idx,*prefetch_k_tile))); + } + ); + } + } + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_sf( + Params const& params, + MainloopSFPipeline pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + int prefetch_k_tile_count = 0) { + + auto tAgSFA_mkl = get<0>(load_inputs); + auto tBgSFB_nkl = get<1>(load_inputs); + auto tAsSFA = get<2>(load_inputs); + auto tBsSFB = get<3>(load_inputs); + auto mcast_mask_sfa = 
get<4>(load_inputs); + auto mcast_mask_sfb = get<5>(load_inputs); + // slice out the work coord from partitioned tensors + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + + using BarrierType = typename MainloopSFPipeline::ProducerBarrierType; + auto tAsSFA_compact = make_tensor(tAsSFA.data(), filter_zeros(tAsSFA.layout())); + auto tBsSFB_compact = make_tensor(tBsSFB.data(), filter_zeros(tBsSFB.layout())); + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadSFPipelineStageCount / SF_BUFFERS_PER_TILE_K; + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadSFPipelineStageCount % SF_BUFFERS_PER_TILE_K; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + // In total, we will load 2 or 4 buffers per k_tile_iter. Unrolled. 
+ CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < SF_BUFFERS_PER_TILE_K; buffer++) { + pipeline.producer_acquire(mainloop_sf_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_sf_pipe_producer_state); + + int write_stage = mainloop_sf_pipe_producer_state.index(); + ++mainloop_sf_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + auto tAgSFA_compact = make_tensor(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + auto tBgSFB_compact = make_tensor(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_sfa_->with(*tma_barrier, mcast_mask_sfa), tAgSFA_compact, tAsSFA_compact(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier, mcast_mask_sfb), tBgSFB_compact, tBsSFB_compact(_,write_stage)); + } + #if 0 + if(threadIdx.x == 256 && blockIdx.x == 1 && blockIdx.y == 0) { + print("tAgSFA_compact: "); print(tAgSFA_compact); print("\n"); + print("tBgSFB_compact: "); print(tBgSFB_compact); print("\n"); + } + #endif + + auto tAgSFA_compact_prefetch = make_tensor(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + auto tBgSFB_compact_prefetch = make_tensor(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + prefetch(*observed_tma_load_sfa_, tAgSFA_compact_prefetch); + prefetch(*observed_tma_load_sfb_, 
tBgSFB_compact_prefetch); + } + ); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_sf_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + template < + class MainloopPipeline, class MainloopPipelineState + > + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class MmaFragmentSFA, class MmaFragmentSFB, + class CtaTileCoord, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto pipeline_ab = get<0>(pipelines); + auto pipeline_sf = get<1>(pipelines); + auto accumulator_pipeline = get<2>(pipelines); + auto mainloop_pipe_ab_consumer_state = get<0>(pipeline_states); + auto mainloop_pipe_sf_consumer_state = get<1>(pipeline_states); + auto accumulator_pipe_producer_state = get<2>(pipeline_states); + auto tiled_mma = get<0>(mma_inputs); + auto tCrA = get<1>(mma_inputs); + auto tCrB = 
get<2>(mma_inputs); + auto tCtSFA = get<3>(mma_inputs); + auto tCtSFB = get<4>(mma_inputs); + auto tCtSFA_mma = get<5>(mma_inputs); + auto tCtSFB_mma = get<6>(mma_inputs); + auto tiled_copy_s2t_SFA = get<7>(mma_inputs); + auto tCsSFA_s2t = get<8>(mma_inputs); + auto tCtSFA_s2t = get<9>(mma_inputs); + auto tiled_copy_s2t_SFB = get<10>(mma_inputs); + auto tCsSFB_s2t = get<11>(mma_inputs); + auto tCtSFB_s2t = get<12>(mma_inputs); + + tCtSFB_mma = [tCtSFB_mma = tCtSFB_mma, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB_mma; + if (get<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else { + return tCtSFB_mma; + } + }(); + + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + constexpr int sf_stride = TiledMma::SFVecSize == 16 ? 6 : 3; + auto barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + auto barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state); + constexpr int MmasPerSfBuffer = 8 / SF_BUFFERS_PER_TILE_K; + + auto sf_load_fn = [&](const int kphase, const int k_tile_count) { + if (kphase % MmasPerSfBuffer == 0) { + pipeline_sf.consumer_wait(mainloop_pipe_sf_consumer_state, barrier_token_sf); + int read_stage_sf_buffer0 = mainloop_pipe_sf_consumer_state.index(); + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, tCsSFA_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, tCsSFB_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFB_s2t); + } + auto buffer0_mainloop_pipe_sf_consumer_state = mainloop_pipe_sf_consumer_state; + ++mainloop_pipe_sf_consumer_state; + barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state, (kphase == 8 - MmasPerSfBuffer) && k_tile_count <= 1); // only skip wait for the last one. 
+ pipeline_sf.consumer_release(buffer0_mainloop_pipe_sf_consumer_state); + } + }; + + bool is_first_iteration = true; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // MMA 0 + sf_load_fn(0, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer0 = mainloop_pipe_ab_consumer_state.index(); + auto buffer0_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + // delay the acc acquire to unblock tmem copy. + if constexpr (IsOverlappingAccum) { + if(is_first_iteration) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + is_first_iteration = false; + } + }; + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,0,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,0,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + + // MMA 1 + sf_load_fn(1, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,3,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 48 bytes. Note the 3. + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,3,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 48 bytes. Note the 3. 
+ tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + + // MMA 2 + sf_load_fn(2, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer1 = mainloop_pipe_ab_consumer_state.index(); + auto buffer1_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,6,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,6,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer0_mainloop_pipe_ab_consumer_state); + + + // MMA 3 + sf_load_fn(3, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,1,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,1,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. 
+ tCtSFB_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 4 + sf_load_fn(4, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,4,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,4,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 5 + sf_load_fn(5, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer2 = mainloop_pipe_ab_consumer_state.index(); + auto buffer2_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state, k_tile_count <= 1); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,7,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,7,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. 
+ tCtSFB_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer1_mainloop_pipe_ab_consumer_state); + + // MMA 6 + sf_load_fn(6, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,2,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,2,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. + tCtSFB_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + // MMA 7 + sf_load_fn(7, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,5,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,5,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. 
+ tCtSFB_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer2_mainloop_pipe_ab_consumer_state); + --k_tile_count; + } + return cute::make_tuple(mainloop_pipe_ab_consumer_state, mainloop_pipe_sf_consumer_state); + } + +protected: + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6d0f5a1524256b695618c06f5e9e58e94ace3d21 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp @@ -0,0 +1,1163 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + int SchedulerPipelineStageCount, + class ClusterShape, + class KernelScheduleType, + class TileShape_, + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomsA_, + class SmemCopyAtomsA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomsB_, + class SmemCopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm120ArrayTmaWarpSpecializedBlockScaled, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomsA_, + SmemCopyAtomsA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomsB_, + SmemCopyAtomsB_, + TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopSm120ArrayTmaWarpSpecializedBlockScaled; + using TileShape = TileShape_; + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + + 
static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + using InternalStrideA = cute::remove_pointer_t; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + using InternalStrideB = cute::remove_pointer_t; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + using InternalLayoutSFA = cute::remove_pointer_t; + using InternalLayoutSFB = cute::remove_pointer_t; + + + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + + using TiledMma = TiledMma_; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using ElementAccumulator = typename TiledMma::ValTypeC; + + static constexpr int SFVecSize = TiledMma::Traits::SFVecSize; + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + + // Gmem copies + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + // Smem copies + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomsA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomsA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomsB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomsB{}))>; + + using SmemCopyAtomsA = SmemCopyAtomsA_; + using 
SmemCopyAtomsB = SmemCopyAtomsB_; + + using SmemCopyAtomA = remove_cvref_t(SmemCopyAtomsA{}))>; + using SmemCopyAtomSFA = remove_cvref_t(SmemCopyAtomsA{}))>; + + using SmemCopyAtomB = remove_cvref_t(SmemCopyAtomsB{}))>; + using SmemCopyAtomSFB = remove_cvref_t(SmemCopyAtomsB{}))>; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + using ArchTag = typename DispatchPolicy::ArchTag; + + static constexpr int ThreadCount = size(TiledMma{}); + + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for A operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B operand smem->rmem reads."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + not cute::is_base_of::value, + "MMA atom must source both A and B operands from rmem for this mainloop."); + static_assert(cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm120_f8f6f4(); + + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ using TmaInternalElementA = cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using TmaInternalElementB = cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using SmemAllocTypeA = cute::conditional_t; + using SmemAllocTypeB = cute::conditional_t; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesMK = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutA{})) * sizeof_bits::value)); + + static constexpr uint32_t TmaTransactionBytesNK = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFB{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutB{})) * sizeof_bits::value)); + + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + alignas(1024) cute::ArrayEngine> smem_A; + alignas(1024) cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_SFA; + cute::TmaDescriptor smem_tensormap_SFB; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + using 
TensorMapStorage = typename SharedStorage::TensorMapStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A{nullptr}; + StrideA dA{}; + ElementB const** ptr_B{nullptr}; + StrideB dB{}; + ElementSF const** ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const** ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + using TMA_SFA = decltype(make_tma_copy( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), InternalLayoutSFA{}), + SmemLayoutSFA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + + using TMA_SFB = decltype(make_tma_copy( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), InternalLayoutSFB{}), + SmemLayoutSFB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + cute::TmaDescriptor* tensormaps; + 
ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + ElementSF const** ptr_SFA; + LayoutSFA layout_SFA; + ElementSF const** ptr_SFB; + LayoutSFB layout_SFB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shapes, Arguments const& args, void* workspace) { + (void) workspace; + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_M = int32_t(size<0>(TileShape{})); + auto init_N = int32_t(size<1>(TileShape{})); + auto init_K = int32_t(size<2>(TileShape{})); + auto init_L = 1; + + // Batches/Groups are managed by using appropriate pointers to input matrices + TmaInternalElementA const* ptr_A_first_batch = nullptr; + TmaInternalElementB const* ptr_B_first_batch = nullptr; + ElementSF const* ptr_SFA_first_batch = nullptr; + ElementSF const* ptr_SFB_first_batch = nullptr; + + InternalStrideA stride_a; + InternalStrideB stride_b; + InternalLayoutSFA layout_SFA; + InternalLayoutSFB layout_SFB; + + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
+ auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + layout_SFA = args.layout_SFA; + layout_SFB = args.layout_SFB; + } + + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + Tensor tensor_sfa = make_tensor(ptr_SFA_first_batch, layout_SFA); + Tensor tensor_sfb = make_tensor(ptr_SFB_first_batch, layout_SFB); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + + typename Params::TMA_SFA tma_load_sfa = make_tma_copy( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + + typename Params::TMA_SFB tma_load_sfb = make_tma_copy( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + TmaTransactionBytes, + TmaTransactionBytesMK, + TmaTransactionBytesNK, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + reinterpret_cast(args.ptr_SFA), + args.layout_SFA, + reinterpret_cast(args.ptr_SFB), + args.layout_SFB + }; + } + + template + static size_t + get_workspace_size(ProblemShape 
const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 4; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + // Temporary adhoc partitioning for scaling factors. 
+ template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFA(SFATensor&& sfatensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfatensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFA_TV = typename Atom::Traits::SFALayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfatensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFB(SFBTensor&& sfbtensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfbtensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFB_TV = typename Atom::Traits::SFBLayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<1>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfbtensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + 
+ // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFA(SFATensor&& sfatensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfatensor).data(), thrfrg_SFA(sfatensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vmk = make_coord(get<0>(thr_vmnk), make_coord(get<1>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFA = thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFA); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFB(SFBTensor&& sfbtensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfbtensor).data(), thrfrg_SFB(sfbtensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vnk = make_coord(get<0>(thr_vmnk), make_coord(get<2>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFB = thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFB); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFA_TV(TiledMma& mma) + { + // (M,K) -> (M,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_A = make_layout(make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), 
size<2>(thr_layout_vmnk)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFA(ref_A, mma).compose(atile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFB_TV(TiledMma& mma) + { + // (N,K) -> (N,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_B = make_layout(make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFB(ref_B, mma).compose(btile, _).compose(thridx_2_thrid, _); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + const int32_t init_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = params.tma_load_a.get_tma_tensor(make_shape(M,K,init_L)); // (m,k,l) + Tensor mB_nkl = params.tma_load_b.get_tma_tensor(make_shape(N,K,init_L)); // (n,k,l) + + // Represent the full tensor of Scale factors + InternalLayoutSFA layout_SFA{}; + InternalLayoutSFB layout_SFB{}; + if constexpr (IsGroupedGemmKernel) { + layout_SFA = params.layout_SFA[0]; + layout_SFB = params.layout_SFB[0]; + } + else { + layout_SFA = params.layout_SFA; + layout_SFB = params.layout_SFB; + } + + Tensor mSFA_mkl = params.tma_load_sfa.get_tma_tensor(shape(layout_SFA)); + Tensor mSFB_nkl = params.tma_load_sfb.get_tma_tensor(shape(layout_SFB)); + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, gSFA_mkl, gSFB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class TensorSFA, class TensorSFB, + class TensorMapA, class TensorMapB, + class TensorMapSFA, class TensorMapSFB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& params, + MainloopPipeline pipeline, + 
PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); // (BLK_M,BLK_K,PIPE) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B, SFA and SFB + // + + auto [gA_mkl, gB_nkl, gSFA_mkl, gSFB_nkl] = load_inputs; + + auto block_tma_a = params.tma_load_a.get_slice(0); + auto block_tma_b = params.tma_load_b.get_slice(0); + + auto block_tma_sfa = params.tma_load_sfa.get_slice(0); + auto block_tma_sfb = params.tma_load_sfb.get_slice(0); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gSFA = gSFA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gSFB = gSFB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Partition source and destination tensors for tma copies + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAgSFA = block_tma_sfa.partition_S(gSFA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsSFA = block_tma_sfa.partition_D(sSFA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgSFB = block_tma_sfb.partition_S(gSFB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsSFB = block_tma_sfb.partition_D(sSFB); // (TMA,TMA_N,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(params.tma_load_a.with(get<0>(input_tensormaps),*tma_barrier), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(params.tma_load_b.with(get<1>(input_tensormaps),*tma_barrier), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + copy(params.tma_load_sfa.with(get<2>(input_tensormaps),*tma_barrier), tAgSFA(_,_,_,*k_tile_iter), tAsSFA(_,_,_,write_stage)); + copy(params.tma_load_sfb.with(get<3>(input_tensormaps),*tma_barrier), tBgSFB(_,_,_,*k_tile_iter), tBsSFB(_,_,_,write_stage)); + + // Advance k tile + ++k_tile_iter; + ++smem_pipe_write; + } + } + __syncwarp(); + } + + /// 
Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + [[maybe_unused]] Params const& params) { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + clear(accum); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); // (BLK_M,BLK_K,PIPE) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thread_mma.partition_fragment_B(sB(_,_,Int<0>{})); // (MMA,MMA_N,MMA_K) + + Tensor tCrSFA = partition_fragment_SFA(sSFA(_,_,Int<0>{}), thread_mma); // (MMA,MMA_M,MMA_K) + Tensor tCrSFB = 
partition_fragment_SFB(sSFB(_,_,Int<0>{}), thread_mma); // (MMA,MMA_N,MMA_K) + + // + // Copy from smem to registers + // + + // A + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S( + as_position_independent_swizzle_tensor(sA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + // B + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S( + as_position_independent_swizzle_tensor(sB)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_M,CPY_K) + + // SFA + auto tile_shape_mnk = tile_shape(tiled_mma); + auto smem_tiled_copy_SFA = make_tiled_copy_impl(SmemCopyAtomSFA{}, + get_layoutSFA_TV(tiled_mma), + make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFA = smem_tiled_copy_SFA.get_thread_slice(thread_idx); + Tensor tCsSFA = smem_thr_copy_SFA.partition_S( + as_position_independent_swizzle_tensor(sSFA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrSFA_copy_view = smem_thr_copy_SFA.retile_D(tCrSFA); // (CPY,CPY_M,CPY_K) + + // SFB + auto smem_tiled_copy_SFB = make_tiled_copy_impl(SmemCopyAtomSFB{}, + get_layoutSFB_TV(tiled_mma), + make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFB = smem_tiled_copy_SFB.get_thread_slice(thread_idx); + Tensor tCsSFB = smem_thr_copy_SFB.partition_S( + as_position_independent_swizzle_tensor(sSFB)); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrSFB_copy_view = smem_thr_copy_SFB.retile_D(tCrSFB); // (CPY,CPY_N,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == 
size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // CPY_K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + CUTE_STATIC_ASSERT_V(size<1>(tCsSFA) == size<1>(tCrSFA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCrSFA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrSFB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCsSFB)); // CPY_K + CUTE_STATIC_ASSERT_V(size<3>(tCsSFA) == size<3>(tCsSFB)); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(sA) == size<2>(sSFA)); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(sB) == size<2>(sSFA)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + int read_stage = smem_pipe_read.index(); + auto tCsA_stage = tCsA(_,_,_,read_stage); + auto tCsB_stage = tCsB(_,_,_,read_stage); + auto tCsSFA_stage = tCsSFA(_,_,_,read_stage); + auto tCsSFB_stage = tCsSFB(_,_,_,read_stage); + + auto copy_kblock = [&](auto k_block) { + // copy smem->rmem for A/B operand + copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); + copy(smem_tiled_copy_B, tCsB_stage(_,_,k_block), tCrB_copy_view(_,_,k_block)); + + // Left shift A,B for FP4 + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_A(MMAOp{}, tCrA_copy_view(_,_,k_block)); + fp4_shift_B(MMAOp{}, tCrB_copy_view(_,_,k_block)); + + + // Copy smem->rmem for SFA/SFB operand + copy(tCsSFA_stage(_,_,k_block), tCrSFA_copy_view(_,_,k_block)); + copy(tCsSFB_stage(_,_,k_block), tCrSFB_copy_view(_,_,k_block)); + }; + + auto gemm_kblock = [&](auto k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, make_zip_tensor(tCrA(_,_,k_block), tCrSFA(_,_,k_block)), 
make_zip_tensor(tCrB(_,_,k_block), tCrSFB(_,_,k_block)), accum); + }; + + pipeline.consumer_wait(smem_pipe_read); + + copy_kblock(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + read_stage = smem_pipe_read.index(); + tCsA_stage = tCsA(_,_,_,read_stage); + tCsB_stage = tCsB(_,_,_,read_stage); + tCsSFA_stage = tCsSFA(_,_,_,read_stage); + tCsSFB_stage = tCsSFB(_,_,_,read_stage); + pipeline.consumer_wait(smem_pipe_read); + } + + copy_kblock(k_block_next); + gemm_kblock(k_block); + + }); + } // k_tile_count + + // + // Hoist out last k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 
0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + + if (k_block_next > 0) { + copy_kblock(k_block_next); + } + gemm_kblock(k_block); + + }); +} + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline, PipelineState, int) { + } + + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + cute::TmaDescriptor* tma_desc_sfa = &gmem_tensormap[sm_idx + 2 * sm_count]; + cute::TmaDescriptor* tma_desc_sfb = &gmem_tensormap[sm_idx + 3 * sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + Tensor pSFA_tensormap = make_tensor(mainloop_params.tma_load_sfa.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFA), Int<1>{}, Int<1>{}); + Tensor pSFB_tensormap = make_tensor(mainloop_params.tma_load_sfb.get_tma_descriptor(), 
Int<1>{}, Int<1>{}); + Tensor sSFB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFB), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + copy(recast(pSFA_tensormap), recast(sSFA_tensormap)); + copy(recast(pSFB_tensormap), recast(sSFB_tensormap)); + } + __syncwarp(); + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_sfa, tma_desc_sfb); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + mainloop_params.ptr_SFA[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + mainloop_params.ptr_SFB[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_SFA = {1,1,1,1,1}; + cute::array prob_stride_SFA = {0,0,0,0,0}; + cute::array prob_shape_B = 
{1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + cute::array prob_shape_SFB = {1,1,1,1,1}; + cute::array prob_stride_SFB = {0,0,0,0,0}; + + TmaInternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + ElementSF const* ptr_SF = nullptr; + Tensor tensor_sfa = make_tensor(ptr_SF, mainloop_params.layout_SFA[next_group]); + + TmaInternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + Tensor tensor_sfb = make_tensor(ptr_SF, mainloop_params.layout_SFB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_sfa, tensor_sfa, + prob_shape_SFA, prob_stride_SFA); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_sfb, tensor_sfb, + prob_shape_SFB, prob_stride_SFB); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_SFA) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_SFB) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + prob_shape_SFA, + prob_stride_SFA); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + 
prob_shape_SFB, + prob_stride_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + + tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_SFA); + tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<2>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<3>(input_tensormaps)); + } + + template + CUTLASS_DEVICE + InputTensors + tensors_perform_update( + InputTensors const& input_tensors, + [[maybe_unused]] Params const& mainloop_params, + [[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl, + [[maybe_unused]] int32_t next_batch) { + return input_tensors; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..84d1ab14caa75497b8ecd0d42cf279a4f634e51f --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp @@ -0,0 +1,887 @@ 
+/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + int SchedulerPipelineStageCount, + class ClusterShape, + class KernelScheduleType, + class TileShape_, + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomsA_, + class SmemCopyAtomsA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomsB_, + class SmemCopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm120TmaWarpSpecializedBlockScaled, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomsA_, + SmemCopyAtomsA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomsB_, + SmemCopyAtomsB_, + TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopSm120TmaWarpSpecializedBlockScaled; + using TileShape = TileShape_; + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + + 
static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + + using TiledMma = TiledMma_; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using ElementAccumulator = typename TiledMma::ValTypeC; + + static constexpr int SFVecSize = TiledMma::Traits::SFVecSize; + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + + // Gmem copies + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + // Smem copies + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomsA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomsA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomsB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomsB{}))>; + + using SmemCopyAtomsA = SmemCopyAtomsA_; + using SmemCopyAtomsB = SmemCopyAtomsB_; + + using SmemCopyAtomA = remove_cvref_t(SmemCopyAtomsA{}))>; + using SmemCopyAtomSFA = remove_cvref_t(SmemCopyAtomsA{}))>; + + using SmemCopyAtomB = 
remove_cvref_t(SmemCopyAtomsB{}))>; + using SmemCopyAtomSFB = remove_cvref_t(SmemCopyAtomsB{}))>; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + using ArchTag = typename DispatchPolicy::ArchTag; + + static constexpr int ThreadCount = size(TiledMma{}); + + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for A operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B operand smem->rmem reads."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + not cute::is_base_of::value, + "MMA atom must source both A and B operands from rmem for this mainloop."); + static_assert(cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm120_f8f6f4(); + + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ using TmaInternalElementA = cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using TmaInternalElementB = cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using TmaInternalElementSF = ElementSF; + + using SmemAllocTypeA = cute::conditional_t; + using SmemAllocTypeB = cute::conditional_t; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesMK = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutA{})) * sizeof_bits::value)); + + static constexpr uint32_t TmaTransactionBytesNK = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFB{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutB{})) * sizeof_bits::value)); + + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + alignas(1024) cute::ArrayEngine> smem_A; + alignas(1024) cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; 
+ ElementSF const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + using TMA_SFA = decltype(make_tma_copy( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + + using TMA_SFB = decltype(make_tma_copy( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; 
+ + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + + typename Params::TMA_SFA tma_load_sfa = make_tma_copy( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + + typename Params::TMA_SFB tma_load_sfb = make_tma_copy( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + args.layout_SFA, + args.layout_SFB, + TmaTransactionBytes, + TmaTransactionBytesMK, + TmaTransactionBytesNK + }; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / 
cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfa.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfb.get_tma_descriptor()); + } + + // Temporary adhoc partitioning for scaling factors. 
+ template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFA(SFATensor&& sfatensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfatensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFA_TV = typename Atom::Traits::SFALayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfatensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFB(SFBTensor&& sfbtensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfbtensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFB_TV = typename Atom::Traits::SFBLayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<1>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfbtensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + 
+ // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFA(SFATensor&& sfatensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfatensor).data(), thrfrg_SFA(sfatensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vmk = make_coord(get<0>(thr_vmnk), make_coord(get<1>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFA = thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFA); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFB(SFBTensor&& sfbtensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfbtensor).data(), thrfrg_SFB(sfbtensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vnk = make_coord(get<0>(thr_vmnk), make_coord(get<2>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFB = thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFB); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFA_TV(TiledMma& mma) + { + // (M,K) -> (M,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_A = make_layout(make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), 
size<2>(thr_layout_vmnk)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFA(ref_A, mma).compose(atile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFB_TV(TiledMma& mma) + { + // (N,K) -> (N,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_B = make_layout(make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFB(ref_B, mma).compose(btile, _).compose(thridx_2_thrid, _); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + Tensor mSFA_mkl = params.tma_load_sfa.get_tma_tensor(shape(params.layout_SFA)); + Tensor mSFB_nkl = params.tma_load_sfb.get_tma_tensor(shape(params.layout_SFB)); + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, gSFA_mkl, gSFB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class TensorSFA, class TensorSFB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = 
make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); // (BLK_M,BLK_K,PIPE) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B, SFA and SFB + // + + auto [gA_mkl, gB_nkl, gSFA_mkl, gSFB_nkl] = load_inputs; + + auto block_tma_a = params.tma_load_a.get_slice(0); + auto block_tma_b = params.tma_load_b.get_slice(0); + + auto block_tma_sfa = params.tma_load_sfa.get_slice(0); + auto block_tma_sfb = params.tma_load_sfb.get_slice(0); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gSFA = gSFA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gSFB = gSFB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Partition source and destination tensors for tma copies + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAgSFA = block_tma_sfa.partition_S(gSFA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsSFA = block_tma_sfa.partition_D(sSFA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgSFB = block_tma_sfb.partition_S(gSFB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsSFB = block_tma_sfb.partition_D(sSFB); // (TMA,TMA_N,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename 
MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(params.tma_load_a.with(*tma_barrier), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(params.tma_load_b.with(*tma_barrier), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + copy(params.tma_load_sfa.with(*tma_barrier), tAgSFA(_,_,_,*k_tile_iter), tAsSFA(_,_,_,write_stage)); + copy(params.tma_load_sfb.with(*tma_barrier), tBgSFB(_,_,_,*k_tile_iter), tBsSFB(_,_,_,write_stage)); + + // Advance k tile + ++k_tile_iter; + ++smem_pipe_write; + } + } + __syncwarp(); + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + [[maybe_unused]] Params const& params) { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + clear(accum); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = 
make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); // (BLK_M,BLK_K,PIPE) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thread_mma.partition_fragment_B(sB(_,_,Int<0>{})); // (MMA,MMA_N,MMA_K) + + Tensor tCrSFA = partition_fragment_SFA(sSFA(_,_,Int<0>{}), thread_mma); // (MMA,MMA_M,MMA_K) + Tensor tCrSFB = partition_fragment_SFB(sSFB(_,_,Int<0>{}), thread_mma); // (MMA,MMA_N,MMA_K) + + // + // Copy from smem to registers + // + + // A + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S( + as_position_independent_swizzle_tensor(sA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + // B + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S( + as_position_independent_swizzle_tensor(sB)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_M,CPY_K) + + // SFA + auto tile_shape_mnk = tile_shape(tiled_mma); + auto smem_tiled_copy_SFA = make_tiled_copy_impl(SmemCopyAtomSFA{}, + get_layoutSFA_TV(tiled_mma), + make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFA = smem_tiled_copy_SFA.get_thread_slice(thread_idx); + Tensor tCsSFA = smem_thr_copy_SFA.partition_S( + as_position_independent_swizzle_tensor(sSFA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrSFA_copy_view = 
smem_thr_copy_SFA.retile_D(tCrSFA); // (CPY,CPY_M,CPY_K) + + // SFB + auto smem_tiled_copy_SFB = make_tiled_copy_impl(SmemCopyAtomSFB{}, + get_layoutSFB_TV(tiled_mma), + make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFB = smem_tiled_copy_SFB.get_thread_slice(thread_idx); + Tensor tCsSFB = smem_thr_copy_SFB.partition_S( + as_position_independent_swizzle_tensor(sSFB)); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrSFB_copy_view = smem_thr_copy_SFB.retile_D(tCrSFB); // (CPY,CPY_N,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // CPY_K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + CUTE_STATIC_ASSERT_V(size<1>(tCsSFA) == size<1>(tCrSFA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCrSFA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrSFB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCsSFB)); // CPY_K + CUTE_STATIC_ASSERT_V(size<3>(tCsSFA) == size<3>(tCsSFB)); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(sA) == size<2>(sSFA)); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(sB) == size<2>(sSFA)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + int read_stage = smem_pipe_read.index(); + auto tCsA_stage = tCsA(_,_,_,read_stage); + auto tCsB_stage = tCsB(_,_,_,read_stage); + auto tCsSFA_stage = tCsSFA(_,_,_,read_stage); + auto tCsSFB_stage = tCsSFB(_,_,_,read_stage); + + auto copy_kblock = [&](auto k_block) { + // copy 
smem->rmem for A/B operand + copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); + copy(smem_tiled_copy_B, tCsB_stage(_,_,k_block), tCrB_copy_view(_,_,k_block)); + + // Left shift A,B for FP4 + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_A(MMAOp{}, tCrA_copy_view(_,_,k_block)); + fp4_shift_B(MMAOp{}, tCrB_copy_view(_,_,k_block)); + + + // Copy smem->rmem for SFA/SFB operand + copy(tCsSFA_stage(_,_,k_block), tCrSFA_copy_view(_,_,k_block)); + copy(tCsSFB_stage(_,_,k_block), tCrSFB_copy_view(_,_,k_block)); + }; + + auto gemm_kblock = [&](auto k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, make_zip_tensor(tCrA(_,_,k_block), tCrSFA(_,_,k_block)), make_zip_tensor(tCrB(_,_,k_block), tCrSFB(_,_,k_block)), accum); + }; + + pipeline.consumer_wait(smem_pipe_read); + + copy_kblock(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + read_stage = smem_pipe_read.index(); + tCsA_stage = tCsA(_,_,_,read_stage); + tCsB_stage = tCsB(_,_,_,read_stage); + tCsSFA_stage = tCsSFA(_,_,_,read_stage); + tCsSFB_stage = tCsSFB(_,_,_,read_stage); + pipeline.consumer_wait(smem_pipe_read); + } + + copy_kblock(k_block_next); + gemm_kblock(k_block); + + }); + } // k_tile_count + + // + // Hoist out last k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 
0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + + if (k_block_next > 0) { + copy_kblock(k_block_next); + } + gemm_kblock(k_block); + + }); +} + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline, PipelineState, int) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03163121718ae8e794fcb0e0ec95cd426b88b6e8 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp @@ -0,0 +1,1320 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/collective/builders/sm1xx_sparse_config.inl" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// CollectiveMma for A/B with different or same stages based on asymmetric DMA. 
+ +template < + int StagesA, + int StagesB, + int StagesE, + int SchedulerPipelineStageCount, + class ClusterShape, + class TileShape_, + class ElementPairA_, + class LayoutPairsA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomsA_, + class SmemCopyAtomsA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomsB_, + class SmemCopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm120TmaWarpSpecializedSparseBlockScaled, + TileShape_, + ElementPairA_, + LayoutPairsA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomsA_, + SmemCopyAtomsA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomsB_, + SmemCopyAtomsB_, + TransformB_> { + // + // Type Aliases + // + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using LayoutPairsA = LayoutPairsA_; + using StridePairB = StridePairB_; + using SmemCopyAtomsA = SmemCopyAtomsA_; + using SmemCopyAtomsB = SmemCopyAtomsB_; + + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm120TmaWarpSpecializedSparseBlockScaled; + using TileShape = TileShape_; + using ElementA = remove_cvref_t(ElementPairA{}))>; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaRaw = typename ElementAMma::raw_type; + using LayoutA = remove_cvref_t(LayoutPairsA{}))>; + using LayoutE = remove_cvref_t(LayoutPairsA{}))>; + using StrideA = remove_cvref_t(LayoutPairsA{}))>; + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + using ElementBMma = typename TiledMma::ValTypeB; + using ElementEMma = typename TiledMma::ValTypeE; + using ElementE = typename ElementEMma::raw_type; + using RegisterE = typename remove_extent::type; + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + + // SFA, SFB and metadata config + 
static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, + "SFA and SFB data types should be the same"); + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(LayoutPairsA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + static constexpr int SFVecSize = TiledMma::Traits::SFVecSize; + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA_{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB_{}))>;; + using SmemCopyAtomA = remove_cvref_t(SmemCopyAtomsA{}))>; + using SmemCopyAtomE = remove_cvref_t(SmemCopyAtomsA{}))>; + using SmemCopyAtomB = remove_cvref_t(SmemCopyAtomsB{}))>; + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomsA_{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomsB_{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomsA_{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomsB_{}))>; + using SmemCopyAtomSFA = remove_cvref_t(SmemCopyAtomsA{}))>; + using SmemCopyAtomSFB = remove_cvref_t(SmemCopyAtomsB{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA_{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB_{}))>; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using GmemTiledCopyE = GmemTiledCopyA; + + // Asymmetric buffering + // Tensor A/B could have different buffering, with TILEK, and STAGEs. + // It let AsymmetricKRatio equals TILEK_A / TILEK_B, to make sure A/B's + // pipeline keep same steps when produce / consume data. + // Currently, AsymmetricKRatio = {1, 2} is the only support. + static constexpr int AsymmetricKRatio = DispatchPolicy::StagesA != DispatchPolicy::StagesB ? 
2 : 1; + + // Construct TileShape for SFB load from GMEM to SMEM. + // It is required to keep consistency with BlockScaled granularity defined in Sm1xxBlkScaledConfig. + // So that TileShape for scaling factor needs to be defined as a multiple of Blk_MN. + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using TileShapeSF = decltype(make_shape(ceil_div(size<0>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}, + ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}, + shape<2>(CtaShape_MNK{}))); + using TileShapeB = decltype(make_shape(size<0>(TileShape{}), + size<1>(TileShape{}), + ceil_div(size<2>(TileShape{}), Int{}))); + + static constexpr int ThreadCount = size(TiledMma{}); + static constexpr int IsCtaN64 = shape<1>(CtaShape_MNK{}) == 64; + static constexpr int TensorAMmaSparsity = ElementAMma::sparsity; + static constexpr int TensorEMmaSparsity = ElementEMma::sparsity; + + // Use two MainloopPipeline for A and B separately. + using MainloopPipelineMK = cutlass::PipelineTmaAsync; + using MainloopPipelineNK = cutlass::PipelineTmaAsync; + using PipelineStateMK = typename cutlass::PipelineState; + using PipelineStateNK = typename cutlass::PipelineState; + using PipelineParams = typename MainloopPipelineMK::Params; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for A 
operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B operand smem->rmem reads."); + + // Tile along modes in a way that maximizes the TMA box size. + // Note: SmemA, SmemSFA and SmemSFB are with same stages, while SmemB is with another stage number. + // SmemSFB is not with same stages as SmemB, as it will not design 1.5x stages if Smem not enough. + // These different stages setting could maximize capacity of latency hide, while keep data in SMEM. + // Metadata may kept in SMEM, or in GMEM/L2, if under SMEM limitation. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShapeB{}), shape<2>(TileShapeB{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + static_assert(DispatchPolicy::StagesA >= 2, "Specialization requires StagesA set to value 2 or more."); + static_assert(DispatchPolicy::StagesB >= 2, "Specialization requires StagesB set to value 2 or more."); + static_assert(not cute::is_base_of::value && + not cute::is_base_of::value, + "MMA atom must source both A and B operands from rmem for this mainloop."); + static_assert(cute::is_same_v, + "GmemTiledCopy - invalid SM90 
TMA copy atom specified."); + static_assert(cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_sparse_f8f6f4(); + + // Is E kept in SMEM or GMEM + static constexpr bool UseSmemE = DispatchPolicy::StagesE != 0; + + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + using TmaInternalElementA = cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + using TmaSourceElementA = cute::conditional_t; + + using TmaInternalElementB = cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + // Set shared memory layout + using SmemAllocTypeA = cute::conditional_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t; + + static constexpr bool is_A_mn_major = cute::is_same_v(LayoutA{})), Int>; + using SparseConfig = cutlass::Sm1xxGemmSparseConfig, + ElementEMma>; + using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + using SmemLayoutE = decltype(tile_to_shape( + SmemLayoutAtomE{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + static constexpr int SmemSizeE = UseSmemE ? cosize(SmemLayoutE{}) : 0; + static constexpr int StageSizeE = UseSmemE ? 
cosize(take<0,2>(SmemLayoutE{})) : 0; + // Check if metetata fetching needs predication + using TensorEAtomM = typename SparseConfig::TensorEAtomM; + using TensorEAtomK = typename SparseConfig::TensorEAtomK; + static constexpr bool IsELoadPred = not (TensorEAtomM{} == size<0>(TileShape{}) && TensorEAtomK{} == size<2>(TileShape{})); + + static_assert(rank(SmemLayoutAtomE{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomE{})) == 0, "SmemLayoutAtomE must evenly divide tile shape."); + + // Set the bytes transferred in this TMA transaction + static constexpr uint32_t TmaTransactionBytesMK = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFB{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(StageSizeE * cute::sizeof_bits_v)); + static constexpr uint32_t TmaTransactionBytesNK = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutB{})) * cute::sizeof_bits_v)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + alignas(1024) cute::ArrayEngine> smem_A; + alignas(1024) cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + cute::ArrayEngine{}> smem_E; + } tensors; + + using PipelineStorageMK = typename MainloopPipelineMK::SharedStorage; + using PipelineStorageNK = typename MainloopPipelineNK::SharedStorage; + alignas(16) PipelineStorageMK pipeline_storage_mk; + alignas(16) PipelineStorageNK pipeline_storage_nk; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorageMK = typename SharedStorage::PipelineStorageMK; + using PipelineStorageNK = typename SharedStorage::PipelineStorageNK; + + struct Arguments { + ElementA 
const* ptr_A{nullptr}; + LayoutA layout_a{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementE const* ptr_E{nullptr}; + LayoutE layout_e{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr>(nullptr), LayoutA{}), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShapeB{}), shape<2>(TileShapeB{})), + _1{})); + using TMA_E = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), LayoutE{}), + SmemLayoutE{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); + using TMA_SFA = decltype(make_tma_copy( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); + using TMA_SFB = decltype(make_tma_copy( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShapeSF{}), shape<2>(TileShapeSF{})), + _1{})); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_E tma_load_e; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + LayoutA layout_a; + LayoutE layout_e; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + ElementE const* ptr_E{nullptr}; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static 
constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr>(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + auto ptr_E = recast_ptr(args.ptr_E); + + Tensor tensor_a = make_tensor(ptr_A, args.layout_a); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + Tensor tensor_e = make_tensor(ptr_E, args.layout_e); + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShapeB{})), + _1{}); + typename Params::TMA_E tma_load_e = make_tma_copy( + GmemTiledCopyE{}, + tensor_e, + SmemLayoutE{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); + typename Params::TMA_SFA tma_load_sfa = make_tma_copy( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); + typename Params::TMA_SFB tma_load_sfb = make_tma_copy( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShapeSF{}), shape<2>(TileShapeSF{})), + _1{}); + return { + tma_load_a, + tma_load_b, + tma_load_e, + tma_load_sfa, + tma_load_sfb, + args.layout_a, + args.layout_e, + args.layout_SFA, + args.layout_SFB, + args.ptr_E + }; + } + + template + CUTLASS_HOST_DEVICE static bool + 
can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::upcast<2>(make_layout(make_shape(M, K, L), StrideA{}))); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_sfa.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_sfb.get_tma_descriptor()); + if constexpr (UseSmemE) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_e.get_tma_descriptor()); + } + } + + /// Create fragment for metadata. The function is referred from thrfrg_A(...) 
+ template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_E(Tensor&& tensor, TiledMMA& mma) { + CUTE_STATIC_ASSERT_V(rank(tensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutE_TV = typename Atom::Traits::ELayout; + + auto t_tile = make_tile(get<0>(TiledPerm{}), + get<2>(TiledPerm{})); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + auto t_tensor = logical_divide(tensor, t_tile); + + // Tile the tensor for the Atom + auto e_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto e_tensor = zipped_divide(t_tensor, e_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = e_tensor.compose(AtomLayoutE_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + // Fragment layout + return thr_tensor; + } + + /// get metadata TV + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutE_TV(TiledMma& mma) + { + // (M,K) -> (M,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_E = make_layout(make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto etile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_E(ref_E, mma).compose(etile, _).compose(thridx_2_thrid, _); + } + + /// Partitioning for metadata. 
+ template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_E(Tensor&& tensor, ThrMma& thread_mma) { + auto thr_tensor = make_tensor(static_cast(tensor).data(), thrfrg_E(tensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + + auto thr_vmk = make_coord(get<0>(thr_vmnk), make_coord(get<1>(thr_vmnk), get<3>(thr_vmnk))); + auto partition = thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition.layout()); + } + + // Temporary adhoc partitioning for scaling factors. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFA(SFATensor&& sfatensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfatensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFA_TV = typename Atom::Traits::SFALayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfatensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFB(SFBTensor&& sfbtensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfbtensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFB_TV = typename 
Atom::Traits::SFBLayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<1>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfbtensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFA(SFATensor&& sfatensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfatensor).data(), thrfrg_SFA(sfatensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vmk = make_coord(get<0>(thr_vmnk), make_coord(get<1>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFA = thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFA); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFB(SFBTensor&& sfbtensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfbtensor).data(), thrfrg_SFB(sfbtensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vnk = make_coord(get<0>(thr_vmnk), make_coord(get<2>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFB = thr_tensor(thr_vnk, 
make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFB); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFA_TV(TiledMma& mma) + { + // (M,K) -> (M,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_A = make_layout(make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFA(ref_A, mma).compose(atile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFB_TV(TiledMma& mma) + { + // (N,K) -> (N,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_B = make_layout(make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFB(ref_B, mma).compose(btile, _).compose(thridx_2_thrid, _); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(mainloop_params.layout_a.shape()); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + Tensor mE_mkl = mainloop_params.tma_load_e.get_tma_tensor(mainloop_params.layout_e.shape()); // (m,k,l) + Tensor mSFA_mkl = mainloop_params.tma_load_sfa.get_tma_tensor(shape(mainloop_params.layout_SFA)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN64) { + Tensor mSFB_tmp = mainloop_params.tma_load_sfb.get_tma_tensor(shape(mainloop_params.layout_SFB)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), _2{}); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride(_0{}), x)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return mainloop_params.tma_load_sfb.get_tma_tensor(shape(mainloop_params.layout_SFB)); + } + }(); + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, 
make_coord(_,_,_), Step<_1, X,_1>{}); // ( BLK_M, BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShapeB{}, make_coord(_,_,_), Step< X,_1,_1>{}); // ( BLK_N, BLK_K,n,k,l) + Tensor gE_mkl = local_tile(mE_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // ( BLK_N, BLK_K,n,k,l) + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShapeSF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gE_mkl, gSFA_mkl, gSFB_nkl); + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + template + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + // Issues loads for A/E/SF only (used when DMA warp is split). 
+ template < + class TensorA, class TensorB, class TensorE, + class TensorSFA, class TensorSFB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_MK( + Params const& params, + MainloopPipelineMK pipeline, + PipelineStateMK smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (BLK_M,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); // (BLK_M,BLK_K,PIPE) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and E + // + + Tensor gA_mkl = get<0>(load_inputs); // (BLK_M,BLK_K,k) + Tensor gE_mkl = get<2>(load_inputs); // (BLK_M,BLK_K,k) + Tensor gSFA_mkl = get<3>(load_inputs); // (BLK_M,BLK_K,k) + Tensor gSFB_nkl = get<4>(load_inputs); // (BLK_N,BLK_K,k) + + auto block_tma_a = params.tma_load_a.get_slice(0); + auto block_tma_e = params.tma_load_e.get_slice(0); + auto block_tma_sfa = params.tma_load_sfa.get_slice(0); + auto block_tma_sfb = params.tma_load_sfb.get_slice(0); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gE = gE_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gSFA = gSFA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gSFB = gSFB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Partition source and destination tensors for tma copies + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K, k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tEgE = block_tma_e.partition_S(gE); // (TMA,TMA_M,TMA_K, k) + Tensor tEsE = block_tma_e.partition_D(sE); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tAgSFA = block_tma_sfa.partition_S(gSFA); // (TMA,TMA_M,TMA_K, k) + Tensor tAsSFA = block_tma_sfa.partition_D(sSFA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tBgSFB = block_tma_sfb.partition_S(gSFB); // (TMA,TMA_N,TMA_K, k) + Tensor tBsSFB = block_tma_sfb.partition_D(sSFB); // (TMA,TMA_N,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + using BarrierType = typename MainloopPipelineMK::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + if (cute::elect_one_sync()) { + copy(params.tma_load_a.with(*tma_barrier), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(params.tma_load_sfa.with(*tma_barrier), tAgSFA(_,_,_,*k_tile_iter), tAsSFA(_,_,_,write_stage)); + copy(params.tma_load_sfb.with(*tma_barrier), tBgSFB(_,_,_,*k_tile_iter), tBsSFB(_,_,_,write_stage)); + if constexpr (UseSmemE) { + copy(params.tma_load_e.with(*tma_barrier), tEgE(_,_,_,*k_tile_iter), tEsE(_,_,_,write_stage)); + } + } + + if constexpr (!UseSmemE) { + // Prefetch 1 stage of E data to L2 in advance + auto blk_coord_mkl = make_coord(get<0>(blk_coord), 
*k_tile_iter, get<3>(blk_coord)); // (BLK_M,BLK_K,L) + prefetch(make_local_E(params, blk_coord_mkl)); + } + + // Advance smem_pipe_write + ++k_tile_iter; + ++smem_pipe_write; + } + } + + // Issues loads for B/SF only (used when DMA warp is split). + template < + class TensorA, class TensorB, class TensorE, + class TensorSFA, class TensorSFB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_NK( + Params const& params, + MainloopPipelineNK pipeline, + PipelineStateNK smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for B + // + + Tensor gB_nkl = get<1>(load_inputs); + auto block_tma_b = params.tma_load_b.get_slice(0); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Partition source and destination tensors for tma copies + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K, k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + using BarrierType = typename MainloopPipelineNK::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + if (cute::elect_one_sync()) { + copy(params.tma_load_b.with(*tma_barrier), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } + // Advance smem_pipe_write + ++k_tile_iter; + ++smem_pipe_write; + } + } + + // Local tile E from global memory. + template + CUTLASS_DEVICE auto + make_local_E(Params const& mainloop_params, + BlockCoord const& blk_coord) { + // E layout + auto layoutE = mainloop_params.layout_e; + // E data pointer as sparse datatype + auto ptr_E = recast_ptr(mainloop_params.ptr_E); + + // Global gmem E + Tensor gE = make_tensor(make_gmem_ptr(ptr_E), layoutE); // (BLK_M,BLK_K,BLK_L) + // Local tile E + return local_tile(gE, select<0,2>(TileShape{}), blk_coord); // (BLK_M,BLK_K) + } + + // Load E from global memory to registers. 
+ template + CUTLASS_DEVICE auto + load_E(Params const& mainloop_params, + BlockCoord const& blk_coord, + ProblemShape_MNKL const& problem_shape_MNKL, + int thread_idx) { + // Workload + auto [M, N, K, L] = problem_shape_MNKL; + auto [m_coord, k_coord, l_coord] = blk_coord; + auto Shape_MK = cute::make_tuple(M, K); + + // Tiled mma and thread mma + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + // Tile shape + auto tile_shape_mnk = tile_shape(tiled_mma); + // Re-sue copy atom E from SmemCopyAtomE + using GmemCopyAtomeE = SmemCopyAtomE; + // Gmem tile copy + auto gmem_tiled_copy_E = make_tiled_copy_impl(GmemCopyAtomeE{}, + get_layoutE_TV(tiled_mma), + make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + // Gmem thread copy + auto gmem_thr_copy_E = gmem_tiled_copy_E.get_thread_slice(thread_idx); + // Gmem local E + auto gE_mkl = make_local_E(mainloop_params, blk_coord); + // Tiled gmem E + Tensor tCgE = gmem_thr_copy_E.partition_S(gE_mkl); // (CPY,CPY_M,CPY_K) + // Tiled register E and copy view + Tensor tCrE = partition_fragment_E(gE_mkl, thread_mma); // (MMA,MMA_M,MMA_K) + Tensor tCrE_copy_view = gmem_thr_copy_E.retile_D(tCrE); // (CPY,CPY_M,CPY_K) + + if constexpr (IsF8F6F4) { + auto get_copy_atom_and_common_vec = [&]() { + using ValType = typename decltype(tCrE)::value_type; + // Get maximum copy vector size (logically) + auto common_layout = max_common_layout(tCgE, tCrE); + auto vec_elem = cute::min(size(common_layout), Int<128 / sizeof_bits_v>{}); + auto common_vec = composition(common_layout, vec_elem); + // Compose a Copy_Atom + using VecType = uint_bit_t>; + using cpy = Copy_Atom, ValType>; + return cute::make_tuple(cpy{}, common_vec); + }; + + // Copy depends on whether predication is needed + if constexpr (IsELoadPred) { + // Get predication based on logical element coordinates. 
+ Tensor cE_mk = local_tile( + make_identity_tensor(Shape_MK), + make_shape(get<0>(TileShape{}), get<2>(TileShape{})), + make_shape(m_coord, k_coord)); // (BLK_M, BLK_K) + Tensor tCcE = gmem_thr_copy_E.partition_S(cE_mk); // (CPY,CPY_M,CPY_K) + auto [atom, vec] = get_copy_atom_and_common_vec(); + // Coordinate comparison for out of bound (OOB) predication + Tensor tZpE = cute::lazy::transform(zipped_divide(tCcE, vec), [&](auto const& c){ return cute::elem_less(c, Shape_MK); }); + // Copy + cute::copy_if(atom, tZpE, zipped_divide(tCgE, vec), zipped_divide(tCrE_copy_view, vec)); + } + else { + // Copy + cute::copy(cute::AutoVectorizingCopyWithAssumedAlignment<32>{}, tCgE, tCrE_copy_view); + } + } + return tCrE; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC, + class KTileIterator, + class CtaTileCoord, + class ProblemShape_MNKL + > + CUTLASS_DEVICE void + mma(MainloopPipelineMK pipeline_mk, + PipelineStateMK smem_pipe_read_mk, + MainloopPipelineNK pipeline_nk, + PipelineStateNK smem_pipe_read_nk, + FrgTensorC& accum, + KTileIterator k_tile_iter, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params, + CtaTileCoord const& cta_tile_coord, + ProblemShape_MNKL const& problem_shape_MNKL) { + using namespace cute; + + CUTE_STATIC_ASSERT(is_rmem::value, "C tensor must be rmem resident."); + + clear(accum); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (BLK_M,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); // (BLK_M,BLK_K,PIPE) + auto SmemLayoutSFB_Ld = [SLayoutSFB = SmemLayoutSFB{}]() { + if constexpr (IsCtaN64) { + auto 
SLayoutSFB_tmp = SLayoutSFB; + auto new_shape = make_shape (make_shape(make_shape(shape<0,0,0>(SLayoutSFB_tmp), + shape<0,0,1>(SLayoutSFB_tmp) / _2{}), shape<0,1>(SLayoutSFB_tmp)), + shape<1>(SLayoutSFB_tmp), shape<2>(SLayoutSFB_tmp)); + auto new_stride = stride(SLayoutSFB_tmp); + return make_layout(new_shape, new_stride); + } + else { + return SLayoutSFB; + } + }(); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()) + + (IsCtaN64 && get<1>(cta_tile_coord) % 2 == 1 ? 8 : 0), SmemLayoutSFB_Ld); // (BLK_N,BLK_K,PIPE) + + // + // Define A/B/E partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thread_mma.partition_fragment_B(sB(_,_,Int<0>{})); // (MMA,MMA_N,MMA_K) + Tensor tCrE = partition_fragment_E(sE(_,_,Int<0>{}), thread_mma); // (MMA,MMA_M,MMA_K) + Tensor tCrSFA = partition_fragment_SFA(sSFA(_,_,Int<0>{}), thread_mma); // (MMA,MMA_M,MMA_K) + Tensor tCrSFB = partition_fragment_SFB(sSFB(_,_,Int<0>{}), thread_mma); // (MMA,MMA_N,MMA_K) + + // + // Copy Atom A, B and E retiling + // + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S( + as_position_independent_swizzle_tensor(sA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S( + as_position_independent_swizzle_tensor(sB)); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) + + auto tile_shape_mnk = tile_shape(tiled_mma); + auto smem_tiled_copy_E = 
make_tiled_copy_impl(SmemCopyAtomE{}, + get_layoutE_TV(tiled_mma), + make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(thread_idx); + Tensor tCsE = smem_thr_copy_E.partition_S( + as_position_independent_swizzle_tensor(sE)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrE_copy_view = smem_thr_copy_E.retile_D(tCrE); // (CPY,CPY_M,CPY_K) + + // SFA + auto smem_tiled_copy_SFA = make_tiled_copy_impl(SmemCopyAtomSFA{}, + get_layoutSFA_TV(tiled_mma), + make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFA = smem_tiled_copy_SFA.get_thread_slice(thread_idx); + Tensor tCsSFA = smem_thr_copy_SFA.partition_S( + as_position_independent_swizzle_tensor(sSFA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrSFA_copy_view = smem_thr_copy_SFA.retile_D(tCrSFA); // (CPY,CPY_M,CPY_K) + + // SFB + auto smem_tiled_copy_SFB = make_tiled_copy_impl(SmemCopyAtomSFB{}, + get_layoutSFB_TV(tiled_mma), + make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFB = smem_tiled_copy_SFB.get_thread_slice(thread_idx); + Tensor tCsSFB = smem_thr_copy_SFB.partition_S( + as_position_independent_swizzle_tensor(sSFB)); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrSFB_copy_view = smem_thr_copy_SFB.retile_D(tCrSFB); // (CPY,CPY_N,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<1>(tCsE) == size<1>(tCrE_copy_view)); + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB) * Int{}); + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == Int{}); + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == Int{}); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); + + CUTE_STATIC_ASSERT_V(size<1>(tCsSFA) == size<1>(tCrSFA_copy_view)); // CPY_M + 
CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCrSFA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrSFB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCsSFB)); // CPY_K + CUTE_STATIC_ASSERT_V(size<3>(tCsSFA) == size<3>(tCsSFB)); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(sA) == size<2>(sSFA)); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(sSFB) == Int{}); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(sB) == Int{}); // PIPE + + if constexpr (UseSmemE) { + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sE)); + } + + // + // DEFINE FUNCTIONS FOR PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineStateMK smem_pipe_release_mk = smem_pipe_read_mk; + PipelineStateNK smem_pipe_release_nk = smem_pipe_read_nk; + + // Wait consumer barrier MK + auto wait_barrier_mk = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto barrier_token_mk = pipeline_mk.consumer_try_wait(smem_pipe_read_mk); + pipeline_mk.consumer_wait(smem_pipe_read_mk, barrier_token_mk); + }; + + // Wait consumer barrier NK + auto wait_barrier_nk = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto barrier_token_nk = pipeline_nk.consumer_try_wait(smem_pipe_read_nk); + pipeline_nk.consumer_wait(smem_pipe_read_nk, barrier_token_nk); + }; + + // Release consumer barrier MK, and move forward + auto release_advance_mk = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + pipeline_mk.consumer_release(smem_pipe_release_mk); + ++smem_pipe_read_mk; + ++smem_pipe_release_mk; + }; + + // Release consumer barrier NK, and move forward + auto release_advance_nk = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + pipeline_nk.consumer_release(smem_pipe_release_nk); + ++smem_pipe_read_nk; + ++smem_pipe_release_nk; + }; + + // Copy A from SMEM to register, and do transform if needed + auto copy_transform_A = [&](auto m_block, auto k_block) CUTLASS_LAMBDA_FUNC_INLINE { + // copy smem->rmem for A operand + copy(smem_tiled_copy_A, 
tCsA(_,m_block,k_block,smem_pipe_read_mk.index()), tCrA_copy_view(_,m_block,k_block)); + // Perform transform if needed. + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_A(MMAOp{}, tCrA_copy_view(_,m_block,k_block)); + }; + + // Copy B from SMEM to register, and do transform if needed + auto copy_transform_B = [&](auto n_block, auto k_block) CUTLASS_LAMBDA_FUNC_INLINE { + // copy smem->rmem for B operand + copy(smem_tiled_copy_B, tCsB(_,n_block,k_block,smem_pipe_read_nk.index()), tCrB_copy_view(_,n_block,k_block)); + // Perform transform if needed. + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_B(MMAOp{}, tCrB_copy_view(_,n_block,k_block)); + }; + + // Copy SFA from SMEM to register + auto copy_SFA = [&](auto m_block, auto k_block) CUTLASS_LAMBDA_FUNC_INLINE { + // Copy smem->rmem for SFA operand + copy(tCsSFA(_,m_block,k_block,smem_pipe_read_mk.index()), tCrSFA_copy_view(_,m_block,k_block)); + }; + + // Copy SFB of all Ns from SMEM to register + auto copy_SFBs = [&](auto k_block) CUTLASS_LAMBDA_FUNC_INLINE { + // Copy smem->rmem for SFB operand + copy(tCsSFB(_,_,k_block,smem_pipe_read_mk.index()), tCrSFB_copy_view(_,_,k_block)); + }; + + // Copy E from SMEM to register + auto copy_E = [&](auto m_block, auto k_block) CUTLASS_LAMBDA_FUNC_INLINE { + // copy smem->rmem for E operand + copy( recast(tCsE(_,m_block,k_block,smem_pipe_read_mk.index())), + recast(tCrE_copy_view(_,m_block,k_block))); + }; + + constexpr auto M_BLOCK_MAX = size<1>(tCrA); + constexpr auto N_BLOCK_MAX = size<1>(tCrB); + constexpr auto K_BLOCK_MAX = size<2>(tCrA); + constexpr auto K_BLOCK_STEP = K_BLOCK_MAX / Int{}; + + // Perform mainloop gemm, when E is in SMEM. + auto gemm_loop_with_SmemE = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + // WAIT on smem_pipe_read until data is available + wait_barrier_mk(); + wait_barrier_nk(); + + // Load A/B/E/SFA/SFB, then do gemm. 
+ for_each(make_int_sequence{}, [&] (auto k_block) { + // Copy smem->rmem for A/B/E operand + copy_transform_A(_, k_block); + copy_transform_B(_, k_block); + copy_E(_, k_block); + + // Copy smem->rmem for SFA/SFB operand + copy_SFA(_, k_block); + copy_SFBs(k_block); + + // Gemm + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,k_block), tCrSFA(_,_,k_block), tCrE(_,_,k_block)), + make_zip_tensor(tCrB(_,_,k_block), tCrSFB(_,_,k_block)), + accum); + + }); + + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + + // Advance consumer pipeline mk/nk + release_advance_mk(); + release_advance_nk(); + }; + + // Perform mainloop gemm, when E is in GMEM. + auto gemm_loop_with_GmemE = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + // Copy gmem->rmem for E operand + auto blk_coord = make_coord(get<0>(cta_tile_coord), *k_tile_iter, get<3>(cta_tile_coord)); // (BLK_M,BLK_K,L) + Tensor tCrE = load_E(mainloop_params, blk_coord, problem_shape_MNKL, thread_idx); + ++k_tile_iter; + + // WAIT on smem_pipe_read until data is available + wait_barrier_mk(); + wait_barrier_nk(); + + for_each(make_int_sequence{}, [&] (auto k_block) { + // Copy smem->rmem for SFB operand. SFB needs to be copied with all N_BLOCK_MAX, + // as each LDS loads several groups of data needed by one MMA instruction. 
+ copy_SFBs(k_block); + + for_each(make_int_sequence{}, [&] (auto n_block) { + // Copy smem->rmem for B operand + copy_transform_B(n_block, k_block); + + for_each(make_int_sequence{}, [&] (auto m_block) { + // Copy smem->rmem for A operand + copy_transform_A(m_block, k_block); + copy_SFA(m_block, k_block); + + // Gemm + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,m_block,k_block), tCrSFA(_,m_block,k_block), tCrE(_,m_block,k_block)), + make_zip_tensor(tCrB(_,n_block,k_block), tCrSFB(_,n_block,k_block)), + accum(_,m_block,n_block)); + }); + }); + }); + + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + + // Advance consumer pipeline_nk + release_advance_nk(); + // Wait next buffer + wait_barrier_nk(); + + for_each(make_int_sequence{}, [&] (auto k_block) { + auto k_block_a = k_block + K_BLOCK_STEP; + + // Copy smem->rmem for SFB operand. SFB needs to be copied with all N_BLOCK_MAX, + // as each LDS loads several groups of data needed by one MMA instruction. + copy_SFBs(k_block_a); + + for_each(make_int_sequence{}, [&] (auto n_block) { + // Copy smem->rmem for B operand + copy_transform_B(n_block, k_block); + + for_each(make_int_sequence{}, [&] (auto m_block) { + // Copy smem->rmem for A operand + copy_transform_A(m_block, k_block_a); + copy_SFA(m_block, k_block_a); + + // Gemm + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,m_block,k_block_a), tCrSFA(_,m_block,k_block_a), tCrE(_,m_block,k_block_a)), + make_zip_tensor(tCrB(_,n_block,k_block), tCrSFB(_,n_block,k_block_a)), + accum(_,m_block,n_block)); + }); + }); + }); + + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + + // Advance consumer pipeline mk/nk + release_advance_mk(); + release_advance_nk(); + }; + + // + // PIPELINED MAIN LOOP + // + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // Case when A/B with same stages, and keep E in SMEM. 
+ if constexpr (UseSmemE) { + gemm_loop_with_SmemE(); + } + // Case when A/B with different stages, and keep E in GMEM. + else { + gemm_loop_with_GmemE(); + } // end if + + } + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipelineMK, PipelineStateMK, MainloopPipelineNK, PipelineStateNK, int) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3fc3d583c9b8880b49bf68933ea11ed13cb68ad4 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp @@ -0,0 +1,1001 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + int SchedulerPipelineStageCount, + class ClusterShape, + class KernelScheduleType, + 
class TileShape_, + class ElementA_, + class StridePairA_, + class ElementB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm120ArrayTmaWarpSpecializedBlockwiseScaling, + TileShape_, + ElementA_, + StridePairA_, + ElementB_, + StridePairB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopSm120ArrayTmaWarpSpecializedBlockwiseScaling; + using TileShape = TileShape_; + using ElementA = remove_cvref_t; + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using InternalStrideA = cute::remove_pointer_t; + using LayoutSFA = cute::remove_cvref_t(StridePairA_{}))>; + using InternalLayoutSFA = cute::remove_pointer_t; + + using ElementB = remove_cvref_t; + using StrideB = cute::remove_cvref_t(StridePairB_{}))>; + using InternalStrideB = cute::remove_pointer_t; + using LayoutSFB = cute::remove_cvref_t(StridePairB_{}))>; + using InternalLayoutSFB = cute::remove_pointer_t; + + using TiledMma = TiledMma_; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementSF = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; + + static constexpr int ThreadCount = size(TiledMma{}); + + using 
MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 33; + + static constexpr int ScaleGranularityM = size<0,0>(InternalLayoutSFA{}); + static constexpr int ScaleGranularityN = size<0,0>(InternalLayoutSFB{}); + static constexpr int ScaleGranularityK = size<1,0>(InternalLayoutSFB{}); + + static_assert(size<1, 0>(InternalLayoutSFA{}) == size<1, 0>(InternalLayoutSFB{}), "Vector size K must be equal for SFA and SFB"); + static_assert(size<0>(TileShape{}) % ScaleGranularityM == 0, "Scale Granularity M must evenly divide the tile shape M."); + static_assert(size<1>(TileShape{}) % ScaleGranularityN == 0, "Scale Granularity N must evenly divide the tile shape N."); + static_assert(size<2>(TileShape{}) == ScaleGranularityK , "Scale Granularity K must be equal to the tile shape K."); + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + + using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig(InternalLayoutSFA{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K, + size<0,1>(InternalLayoutSFB{}.stride()) == 1 ? 
UMMA::Major::MN : UMMA::Major::K>; + + static constexpr int AlignmentSFA = 1; + static constexpr int AlignmentSFB = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for A operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B operand smem->rmem reads."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + // we can have partial tiles in M or N, so don't vectorize those loads + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementSF>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementSF>; + + // Block scaling smem layout + using SmemLayoutScaleA = Layout, Int>>; + using SmemLayoutScaleB = Layout, Int>>; + + + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + not cute::is_base_of::value, + "MMA atom must source both A and B operands from rmem for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm120_f8f6f4(); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ using TmaInternalElementA = cute::conditional_t, + cutlass::tfloat32_t, + cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + using TmaInternalElementB = cute::conditional_t, + cutlass::tfloat32_t, + cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using SmemAllocTypeA = cute::conditional_t; + using SmemAllocTypeB = cute::conditional_t; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesMK = static_cast( + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutA{})) * sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = static_cast( + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutB{})) * sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + alignas(1024) cute::array_aligned> smem_A; + alignas(1024) cute::array_aligned> smem_B; + cute::array_aligned> smem_scale_A; + cute::array_aligned> smem_scale_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + 
struct Arguments { + ElementA const** ptr_A{nullptr}; + StrideA dA{}; + ElementB const** ptr_B{nullptr}; + StrideB dB{}; + ElementAccumulator const** ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementAccumulator const** ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + // Block scaling factors for A and B + cute::TmaDescriptor* tensormaps; + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + ElementSF const** ptr_SFA; + LayoutSFA layout_SFA; + ElementSF const** ptr_SFB; + LayoutSFB layout_SFB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shapes, Arguments const& args, void* workspace) { + (void) workspace; + + auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + constexpr int tma_alignment_bits = 128; + auto init_M = tma_alignment_bits; + auto init_N = tma_alignment_bits; + auto init_K = tma_alignment_bits; + const 
uint32_t init_L = 1; + TmaInternalElementA const* ptr_A_first_batch = nullptr; + TmaInternalElementB const* ptr_B_first_batch = nullptr; + InternalStrideA stride_a; + InternalStrideB stride_b; + + if constexpr (IsGroupedGemmKernel) { + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M, init_K, init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N, init_K, init_L), stride_b)); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + TmaTransactionBytesMK, + TmaTransactionBytesNK, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + reinterpret_cast(args.ptr_SFA), + args.layout_SFA, + reinterpret_cast(args.ptr_SFB), + args.layout_SFB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTmaTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + return (NumInputTmaTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + 
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + for (int i = 0; i < problem_shapes.groups(); ++i) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + // Ensure complete scale blocks + implementable = implementable && (M % ScaleGranularityM == 0); + implementable = implementable && (N % ScaleGranularityN == 0); + + // We expect full tiles in K + implementable = implementable && (K % size<2>(TileShape{}) == 0); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for blockwise scaling.\n"); + } + } + } + + return implementable; + } + + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params, + ElementSF const* ptr_SFA = nullptr, + ElementSF const* ptr_SFB = nullptr, + InternalLayoutSFA const layout_SFA = InternalLayoutSFA{}, + InternalLayoutSFB const layout_SFB = InternalLayoutSFB{} + ) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + const int32_t init_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,init_L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,init_L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(ptr_SFA), filter(layout_SFA)); // (Ms, Ks) + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(ptr_SFB), filter(layout_SFB)); // (Ns, Ks) + + return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class TensorSFA, class TensorSFB, + class TensorMapA, class TensorMapB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( 
+ Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mSFA_mkl = get<2>(load_inputs); + Tensor mSFB_nkl = get<3>(load_inputs); + auto scales_m = get<0>(mSFA_mkl.shape()); + auto scales_n = get<0>(mSFB_nkl.shape()); + + Tensor cSFA_mkl = make_identity_tensor(mSFA_mkl.shape()); + Tensor cSFB_nkl = make_identity_tensor(mSFB_nkl.shape()); + Tensor gSFA = local_tile( + mSFA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) + Tensor cSFA = local_tile( + cSFA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); + Tensor gSFB = local_tile( + mSFB_nkl, make_tile(Int{}), + make_coord(n_coord,_,l_coord)); // (ScaleNsPerTile,k,1) + Tensor cSFB = local_tile( + cSFB_nkl, make_tile(Int{}), + make_coord(n_coord,_,l_coord)); + + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, Layout>{}); + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, + Layout>{}, Layout>{}); + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(thread_idx); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(thread_idx); + + Tensor tAgA_SFA = thr_scale_copy_a.partition_S(gSFA); + Tensor tAcA_SFA = thr_scale_copy_a.partition_S(cSFA); + Tensor tAsA_SFA = thr_scale_copy_a.partition_D(sSFA); + + Tensor tBgB_SFB = thr_scale_copy_b.partition_S(gSFB); + Tensor tBcB_SFB = thr_scale_copy_b.partition_S(cSFB); + Tensor tBsB_SFB = thr_scale_copy_b.partition_D(sSFB); + + Tensor tApA_SFA = make_tensor(shape(tAsA_SFA(_,_,0))); + Tensor tBpB_SFB = make_tensor(shape(tBsB_SFB(_,_,0))); + + auto scale_m_lim = std::min(scales_m, (m_coord + 1) * ScaleMsPerTile); + auto scale_n_lim = std::min(scales_n, (n_coord + 1) * ScaleNsPerTile); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tApA_SFA); ++i) + tApA_SFA(i) = 
get<0>(tAcA_SFA(i)) < scale_m_lim; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tBpB_SFB); ++i) + tBpB_SFB(i) = get<0>(tBcB_SFB(i)) < scale_n_lim; + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // TMA Multicast Masks + Layout cta_layout_mnk = make_layout(ClusterShape{}); + auto cta_coord_mnk = cta_layout_mnk.get_flat_coord(block_rank_in_cluster); + + uint16_t mcast_mask_a = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<0>(cta_layout_mnk, cta_coord_mnk); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + int write_stage = smem_pipe_write.index(); + if (lane_predicate) { + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } + + // Copy scale tensors + copy_if(scale_copy_a, tApA_SFA, tAgA_SFA(_,_,*k_tile_iter), tAsA_SFA(_,_,write_stage)); + copy_if(scale_copy_b, tBpB_SFB, tBgB_SFB(_,_,*k_tile_iter), tBsB_SFB(_,_,write_stage)); + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + __syncwarp(); + } + + /// Perform a Producer Epilogue to prevent early 
exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + FrgTensorC tmp_accum; + clear(accum); + clear(tmp_accum); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, + Stride, _0, Int> + >{}); // ((ScaleGranularityM,ScaleMsPerTile),TileShape_N,stage) + Tensor sScaleBViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), + Layout< + Shape, Shape, Int>, Int>, + Stride<_0, Stride<_0, _1>, Int> + >{}); // (TileShape_M,(ScaleGranularityN,ScaleNsPerTile),stage) + + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCrA = 
thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thread_mma.partition_fragment_B(sB(_,_,Int<0>{})); // (MMA,MMA_N,MMA_K) + + Tensor tCsScaleAViewAsC = thread_mma.partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsScaleBViewAsC = thread_mma.partition_C(sScaleBViewAsC); // (MMA,MMA_M,MMA_N,PIPE) + + // + // Copy Atom A and B retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S( + as_position_independent_swizzle_tensor(sA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S( + as_position_independent_swizzle_tensor(sB)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_M,CPY_K) + + Tensor tCrScaleAViewAsC = make_tensor_like(tCsScaleAViewAsC(_,_,_,_0{})); // (MMA,MMA_M,MMA_N) + Tensor tCrScaleBViewAsC = make_tensor_like(tCsScaleBViewAsC(_,_,_,_0{})); // (MMA,MMA_M,MMA_N) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); + + // + // PIPELINED MAIN LOOP + // + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + int read_stage = smem_pipe_read.index(); + auto tCsA_stage = tCsA(_,_,_,read_stage); + auto tCsB_stage = tCsB(_,_,_,read_stage); + + auto copy_kblock = 
[&](auto k_block) { + // copy smem->rmem for A/B operand + copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); + copy(smem_tiled_copy_B, tCsB_stage(_,_,k_block), tCrB_copy_view(_,_,k_block)); + + // Left shift A,B for FP4 + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_A(MMAOp{}, tCrA_copy_view(_,_,k_block)); + fp4_shift_B(MMAOp{}, tCrB_copy_view(_,_,k_block)); + }; + + auto copy_scale_s2r = [&](auto read_stage) { + copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC); + copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC); + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0]; + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementSF scale_b = tCrScaleBViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleAViewAsC); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementSF scale_a = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleBViewAsC); i++) { + tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a; + } + } + }; + + auto rescale = [&]() { + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementSF scale_ab = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * scale_ab; + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleAViewAsC(i); + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * 
tCrScaleBViewAsC(i); + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleAViewAsC(i) * tCrScaleBViewAsC(i); + tmp_accum(i) = 0; + } + } + }; + + auto gemm_kblock = [&](auto k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tmp_accum); + }; + + pipeline.consumer_wait(smem_pipe_read); + copy_scale_s2r(read_stage); + copy_kblock(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + read_stage = smem_pipe_read.index(); + tCsA_stage = tCsA(_,_,_,read_stage); + tCsB_stage = tCsB(_,_,_,read_stage); + pipeline.consumer_wait(smem_pipe_read); + } + + copy_kblock(k_block_next); + gemm_kblock(k_block); + + if (k_block == K_BLOCK_MAX - 1) { + rescale(); + copy_scale_s2r(read_stage); + } + + }); + + } // k_tile_count + + // + // Hoist out last k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 
0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + + if (k_block_next > 0) { + copy_kblock(k_block_next); + } + gemm_kblock(k_block); + + }); + rescale(); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline, PipelineState, int) { + } + + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + 
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + template + CUTLASS_DEVICE void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + TmaInternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + + TmaInternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE void + 
tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + + template + CUTLASS_DEVICE InputTensors + tensors_perform_update( + InputTensors const& input_tensors, + Params const& mainloop_params, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if constexpr (IsGroupedGemmKernel) { + return load_init( + problem_shape_mnkl, + mainloop_params, + mainloop_params.ptr_SFA[next_batch], + mainloop_params.ptr_SFB[next_batch], + mainloop_params.layout_SFA[next_batch], + mainloop_params.layout_SFB[next_batch] + ); + } + else { + auto [gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl] = input_tensors; + + mSFA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA[next_batch]), mainloop_params.layout_SFA[next_batch]); + mSFB_nkl = 
make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB[next_batch]), mainloop_params.layout_SFB[next_batch]); + + return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_mma_tma.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_mma_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..65f83330a76d56a26aa5b9c5c2531828660dd22a --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_mma_tma.hpp @@ -0,0 +1,587 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + int SchedulerPipelineStageCount, + class ClusterShape, + class KernelScheduleType, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class 
TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm120TmaWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopSm120TmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; + + static constexpr int ThreadCount = size(TiledMma{}); + + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, 
K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for A operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B operand smem->rmem reads."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + not cute::is_base_of::value, + "MMA atom must source both A and B operands from rmem for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm120_f8f6f4(); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ using TmaInternalElementA = cute::conditional_t, + cutlass::tfloat32_t, + cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + using TmaInternalElementB = cute::conditional_t, + cutlass::tfloat32_t, + cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using SmemAllocTypeA = cute::conditional_t; + using SmemAllocTypeB = cute::conditional_t; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesMK = static_cast( + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutA{})) * sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = static_cast( + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutB{})) * sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + alignas(1024) cute::array_aligned> smem_A; + alignas(1024) cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), 
StrideA{}), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + TmaTransactionBytesMK, + 
TmaTransactionBytesNK + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = 
get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), 
tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + clear(accum); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thread_mma.partition_fragment_B(sB(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + + // + // Copy Atom A and B retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S( + 
as_position_independent_swizzle_tensor(sA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S( + as_position_independent_swizzle_tensor(sB)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_M,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); + + // + // PIPELINED MAIN LOOP + // + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + int read_stage = smem_pipe_read.index(); + auto tCsA_stage = tCsA(_,_,_,read_stage); + auto tCsB_stage = tCsB(_,_,_,read_stage); + + auto copy_kblock = [&](auto k_block) { + // copy smem->rmem for A/B operand + copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); + copy(smem_tiled_copy_B, tCsB_stage(_,_,k_block), tCrB_copy_view(_,_,k_block)); + + // Left shift A,B for FP4 + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_A(MMAOp{}, tCrA_copy_view(_,_,k_block)); + fp4_shift_B(MMAOp{}, tCrB_copy_view(_,_,k_block)); + }; + + auto gemm_kblock = [&](auto k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), accum); + }; + + pipeline.consumer_wait(smem_pipe_read); + + copy_kblock(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto 
k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + read_stage = smem_pipe_read.index(); + tCsA_stage = tCsA(_,_,_,read_stage); + tCsB_stage = tCsB(_,_,_,read_stage); + pipeline.consumer_wait(smem_pipe_read); + } + + copy_kblock(k_block_next); + gemm_kblock(k_block); + + }); + } // k_tile_count + + // + // Hoist out last k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + + if (k_block_next > 0) { + copy_kblock(k_block_next); + } + gemm_kblock(k_block); + + }); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline, PipelineState, int) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2f77d66468788789801044fa95bb5528e9aa051c --- /dev/null +++ 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp @@ -0,0 +1,779 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + int SchedulerPipelineStageCount, + class ClusterShape, + class KernelScheduleType, + class TileShape_, + class ElementA_, + class StridePairA_, + class ElementB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm120TmaWarpSpecializedBlockwiseScaling, + TileShape_, + ElementA_, + StridePairA_, + ElementB_, + StridePairB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopSm120TmaWarpSpecializedBlockwiseScaling; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using LayoutSFA = cute::remove_cvref_t(StridePairA_{}))>; + using ElementB = ElementB_; + using StrideB = cute::remove_cvref_t(StridePairB_{}))>; + using 
LayoutSFB = cute::remove_cvref_t(StridePairB_{}))>; + using TiledMma = TiledMma_; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementSF = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; + + static constexpr int ThreadCount = size(TiledMma{}); + + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 33; + + static constexpr int ScaleGranularityM = size<0,0>(LayoutSFA{}); + static constexpr int ScaleGranularityN = size<0,0>(LayoutSFB{}); + static constexpr int ScaleGranularityK = size<1,0>(LayoutSFB{}); + + static_assert(size<1, 0>(LayoutSFA{}) == size<1, 0>(LayoutSFB{}), "Vector size K must be equal for SFA and SFB"); + static_assert(size<0>(TileShape{}) % ScaleGranularityM == 0, "Scale Granularity M must evenly divide the tile shape M."); + static_assert(size<1>(TileShape{}) % ScaleGranularityN == 0, "Scale Granularity N must evenly divide the tile shape N."); + static_assert(size<2>(TileShape{}) == ScaleGranularityK , "Scale Granularity K must be equal to the tile shape K."); + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + + using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig(LayoutSFA{}.stride()) == 1 ? 
UMMA::Major::MN : UMMA::Major::K, + size<0,1>(LayoutSFB{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K>; + + static constexpr int AlignmentSFA = 1; + static constexpr int AlignmentSFB = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for A operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B operand smem->rmem reads."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + // we can have partial tiles in M or N, so don't vectorize those loads + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementSF>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementSF>; + + // Block scaling smem layout + using SmemLayoutScaleA = Layout, Int>>; + using SmemLayoutScaleB = Layout, Int>>; + + + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + not cute::is_base_of::value, + "MMA atom must source both A and B operands from rmem for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm120_f8f6f4(); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ using TmaInternalElementA = cute::conditional_t, + cutlass::tfloat32_t, + cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + using TmaInternalElementB = cute::conditional_t, + cutlass::tfloat32_t, + cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using SmemAllocTypeA = cute::conditional_t; + using SmemAllocTypeB = cute::conditional_t; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesMK = static_cast( + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutA{})) * sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = static_cast( + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutB{})) * sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + alignas(1024) cute::array_aligned> smem_A; + alignas(1024) cute::array_aligned> smem_B; + cute::array_aligned> smem_scale_A; + cute::array_aligned> smem_scale_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementAccumulator const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementAccumulator const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + }; + + // Device side kernel 
params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + // Block scaling factors for A and B + ElementSF const* ptr_SFA; + LayoutSFA layout_SFA; + ElementSF const* ptr_SFB; + LayoutSFB layout_SFB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename 
Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + TmaTransactionBytesMK, + TmaTransactionBytesNK, + args.ptr_SFA, + args.layout_SFA, + args.ptr_SFB, + args.layout_SFB + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + // Ensure complete scale blocks + implementable = implementable && (M % ScaleGranularityM == 0); + implementable = implementable && (N % ScaleGranularityN == 0); + + // We expect full tiles in K + implementable = implementable && (K % size<2>(TileShape{}) == 0); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the alignment requirements for blockwise scaling.\n"); + } + + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void 
prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA), filter(mainloop_params.layout_SFA)); // (Ms, Ks) + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB), filter(mainloop_params.layout_SFB)); // (Ns, Ks) + + return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class 
TensorB, + class TensorSFA, class TensorSFB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mSFA_mkl = get<2>(load_inputs); + Tensor mSFB_nkl = get<3>(load_inputs); + auto scales_m = get<0>(mSFA_mkl.shape()); + auto scales_n = get<0>(mSFB_nkl.shape()); + + Tensor cSFA_mkl = make_identity_tensor(mSFA_mkl.shape()); + Tensor cSFB_nkl = make_identity_tensor(mSFB_nkl.shape()); + Tensor gSFA = local_tile( + mSFA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) + Tensor cSFA = local_tile( + cSFA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); + Tensor gSFB = local_tile( + mSFB_nkl, make_tile(Int{}), + make_coord(n_coord,_,l_coord)); // (ScaleNsPerTile,k,1) + Tensor cSFB = local_tile( + cSFB_nkl, make_tile(Int{}), + make_coord(n_coord,_,l_coord)); + + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, Layout>{}); + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, + Layout>{}, Layout>{}); + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(thread_idx); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(thread_idx); + + Tensor tAgA_SFA = thr_scale_copy_a.partition_S(gSFA); + Tensor tAcA_SFA = thr_scale_copy_a.partition_S(cSFA); + Tensor tAsA_SFA = thr_scale_copy_a.partition_D(sSFA); + + Tensor tBgB_SFB = thr_scale_copy_b.partition_S(gSFB); + Tensor tBcB_SFB = thr_scale_copy_b.partition_S(cSFB); + Tensor tBsB_SFB = thr_scale_copy_b.partition_D(sSFB); + + Tensor tApA_SFA = make_tensor(shape(tAsA_SFA(_,_,0))); + Tensor tBpB_SFB = make_tensor(shape(tBsB_SFB(_,_,0))); + + auto scale_m_lim = std::min(scales_m, (m_coord + 1) * ScaleMsPerTile); + auto scale_n_lim = std::min(scales_n, (n_coord + 1) * ScaleNsPerTile); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tApA_SFA); ++i) + tApA_SFA(i) = 
get<0>(tAcA_SFA(i)) < scale_m_lim; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tBpB_SFB); ++i) + tBpB_SFB(i) = get<0>(tBcB_SFB(i)) < scale_n_lim; + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // TMA Multicast Masks + Layout cta_layout_mnk = make_layout(ClusterShape{}); + auto cta_coord_mnk = cta_layout_mnk.get_flat_coord(block_rank_in_cluster); + + uint16_t mcast_mask_a = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<0>(cta_layout_mnk, cta_coord_mnk); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + int write_stage = smem_pipe_write.index(); + if (lane_predicate) { + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } + + // Copy scale tensors + copy_if(scale_copy_a, tApA_SFA, tAgA_SFA(_,_,*k_tile_iter), tAsA_SFA(_,_,write_stage)); + copy_if(scale_copy_b, tBpB_SFB, tBgB_SFB(_,_,*k_tile_iter), tBsB_SFB(_,_,write_stage)); + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + 
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + FrgTensorC tmp_accum; + clear(accum); + clear(tmp_accum); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, + Stride, _0, Int> + >{}); // ((ScaleGranularityM,ScaleMsPerTile),TileShape_N,stage) + Tensor sScaleBViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), + Layout< + Shape, Shape, Int>, Int>, + Stride<_0, Stride<_0, _1>, Int> + >{}); // (TileShape_M,(ScaleGranularityN,ScaleNsPerTile),stage) + + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor 
tCrB = thread_mma.partition_fragment_B(sB(_,_,Int<0>{})); // (MMA,MMA_N,MMA_K) + + Tensor tCsScaleAViewAsC = thread_mma.partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsScaleBViewAsC = thread_mma.partition_C(sScaleBViewAsC); // (MMA,MMA_M,MMA_N,PIPE) + + // + // Copy Atom A and B retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S( + as_position_independent_swizzle_tensor(sA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S( + as_position_independent_swizzle_tensor(sB)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_M,CPY_K) + + Tensor tCrScaleAViewAsC = make_tensor_like(tCsScaleAViewAsC(_,_,_,_0{})); // (MMA,MMA_M,MMA_N) + Tensor tCrScaleBViewAsC = make_tensor_like(tCsScaleBViewAsC(_,_,_,_0{})); // (MMA,MMA_M,MMA_N) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); + + // + // PIPELINED MAIN LOOP + // + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + int read_stage = smem_pipe_read.index(); + auto tCsA_stage = tCsA(_,_,_,read_stage); + auto tCsB_stage = tCsB(_,_,_,read_stage); + + auto copy_kblock = [&](auto k_block) { + // copy smem->rmem for A/B operand + 
copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); + copy(smem_tiled_copy_B, tCsB_stage(_,_,k_block), tCrB_copy_view(_,_,k_block)); + + // Left shift A,B for FP4 + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_A(MMAOp{}, tCrA_copy_view(_,_,k_block)); + fp4_shift_B(MMAOp{}, tCrB_copy_view(_,_,k_block)); + }; + + auto copy_scale_s2r = [&](auto read_stage) { + copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC); + copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC); + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0]; + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementSF scale_b = tCrScaleBViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleAViewAsC); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementSF scale_a = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleBViewAsC); i++) { + tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a; + } + } + }; + + auto rescale = [&]() { + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementSF scale_ab = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * scale_ab; + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleAViewAsC(i); + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleBViewAsC(i); + tmp_accum(i) = 0; + } + } + if 
constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleAViewAsC(i) * tCrScaleBViewAsC(i); + tmp_accum(i) = 0; + } + } + }; + + auto gemm_kblock = [&](auto k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tmp_accum); + }; + + pipeline.consumer_wait(smem_pipe_read); + copy_scale_s2r(read_stage); + copy_kblock(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + read_stage = smem_pipe_read.index(); + tCsA_stage = tCsA(_,_,_,read_stage); + tCsB_stage = tCsB(_,_,_,read_stage); + pipeline.consumer_wait(smem_pipe_read); + } + + copy_kblock(k_block_next); + gemm_kblock(k_block); + + if (k_block == K_BLOCK_MAX - 1) { + rescale(); + copy_scale_s2r(read_stage); + } + + }); + + } // k_tile_count + + // + // Hoist out last k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 
0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + + if (k_block_next > 0) { + copy_kblock(k_block_next); + } + gemm_kblock(k_block); + + }); + rescale(); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline, PipelineState, int) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7eec27bcf27acf8d7b93936b83991955ee37b854 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp @@ -0,0 +1,988 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/collective/builders/sm1xx_sparse_config.inl" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesA, + int StagesB, + int StagesE, + int SchedulerPipelineStageCount, + class ClusterShape, + class TileShape_, + class ElementA_, + class LayoutPairAE_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomPairA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm120TmaWarpSpecializedSparse, + TileShape_, + ElementA_, + LayoutPairAE_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomPairA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm120TmaWarpSpecializedSparse; + using TileShape = TileShape_; + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + 
using ElementAMmaRaw = typename ElementAMma::raw_type; + using LayoutPairAE = LayoutPairAE_; + using LayoutA = remove_cvref_t(LayoutPairAE{}))>; + using LayoutE = remove_cvref_t(LayoutPairAE{}))>; + using StrideA = remove_cvref_t(LayoutPairAE{}))>; + using ElementB = ElementB_; + using StrideB = StrideB_; + using ElementBMma = typename TiledMma::ValTypeB; + using ElementEMma = typename TiledMma::ValTypeE; + using ElementE = typename ElementEMma::raw_type; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = remove_cvref_t(SmemCopyAtomPairA_{}))>; + using SmemCopyAtomE = remove_cvref_t(SmemCopyAtomPairA_{}))>; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using GmemTiledCopyE = GmemTiledCopyA_; + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + using RegisterE = typename remove_extent::type; + + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; + + static constexpr int ThreadCount = size(TiledMma{}); + static constexpr int ElementAMmaSparsity = ElementAMma::sparsity; + static constexpr int ElementEMmaSparsity = ElementEMma::sparsity; + + // Asymmetric buffering + // Tensor A/B could have different buffering, with TILEK, and STAGEs. + // It let AsymmetricKRatio equals TILEK_A / TILEK_B, to make sure A/B's + // pipeline keep same steps when produce / consume data. + static constexpr int AsymmetricKRatio = DispatchPolicy::StagesA != DispatchPolicy::StagesB ? 2 : 1; + + using TileShapeB = decltype(make_shape(size<0>(TileShape{}), + size<1>(TileShape{}), + ceil_div(size<2>(TileShape{}), Int{}))); + + // Use two MainloopPipeline for A and B separately. 
+ using MainloopPipelineMK = cutlass::PipelineTmaAsync; + using MainloopPipelineNK = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipelineMK::Params; + using PipelineStateMK = typename cutlass::PipelineState; + using PipelineStateNK = typename cutlass::PipelineState; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for A operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B operand smem->rmem reads."); + + static_assert(DispatchPolicy::StagesA >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(DispatchPolicy::StagesB >= 2, "Specialization requires Stages set to value 2 or more."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShapeB{}), shape<2>(TileShapeB{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + static_assert(not cute::is_base_of::value && + not cute::is_base_of::value, + "MMA atom must source both A and B operands from rmem for this mainloop."); + static_assert(cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_sparse_f8f6f4(); + + // Is E kept in SMEM or GMEM + static constexpr bool UseSmemE = DispatchPolicy::StagesE != 0; + + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ using TmaInternalElementA = cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using TmaInternalElementB = cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + // Set shared memory layout + using SmemAllocTypeA = cute::conditional_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t; + + static constexpr bool is_A_mn_major = cute::is_same_v(LayoutA{})), Int>; + using SparseConfig = cutlass::Sm1xxGemmSparseConfig< + ElementAMma, + cute::conditional_t, + ElementEMma>; + using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + using SmemLayoutE = decltype(tile_to_shape( + SmemLayoutAtomE{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + static constexpr int SmemSizeE = UseSmemE ? cosize(SmemLayoutE{}) : 0; + static constexpr int StageSizeE = UseSmemE ? 
cosize(take<0,2>(SmemLayoutE{})) : 0; + // Check if metetata fetching needs predicator + using TensorEAtomM = typename SparseConfig::TensorEAtomM; + using TensorEAtomK = typename SparseConfig::TensorEAtomK; + static constexpr bool IsELoadPred = not (TensorEAtomM{} == size<0>(TileShape{}) && TensorEAtomK{} == size<2>(TileShape{})); + + static_assert(rank(SmemLayoutAtomE{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomE{})) == 0, "SmemLayoutAtomE must evenly divide tile shape."); + + // Set the bytes transferred in this TMA transaction + static constexpr uint32_t TmaTransactionBytesMK = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(StageSizeE * cute::sizeof_bits_v)); + static constexpr uint32_t TmaTransactionBytesNK = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutB{})) * cute::sizeof_bits_v)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + alignas(1024) cute::ArrayEngine> smem_A; + alignas(1024) cute::ArrayEngine> smem_B; + cute::ArrayEngine{}> smem_E; + } tensors; + + using PipelineStorageMK = typename MainloopPipelineMK::SharedStorage; + using PipelineStorageNK = typename MainloopPipelineNK::SharedStorage; + alignas(16) PipelineStorageMK pipeline_storage_mk; + alignas(16) PipelineStorageNK pipeline_storage_nk; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorageMK = typename SharedStorage::PipelineStorageMK; + using PipelineStorageNK = typename SharedStorage::PipelineStorageNK; + + struct Arguments { + ElementA const* ptr_A{nullptr}; + LayoutA layout_a{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementE const* ptr_E{nullptr}; + LayoutE layout_e{}; + }; + + // Device side kernel params + struct Params { + // Assumption: 
StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr>(nullptr), LayoutA{}), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); + + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShapeB{}), shape<2>(TileShapeB{})), + _1{})); + using TMA_E = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), LayoutE{}), + SmemLayoutE{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_E tma_load_e; + LayoutA layout_a; + LayoutE layout_e; + ElementE const* ptr_E{nullptr}; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr>(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + auto ptr_E = recast_ptr(args.ptr_E); + + Tensor tensor_a = make_tensor(ptr_A, args.layout_a); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + Tensor tensor_e = make_tensor(ptr_E, args.layout_e); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); + typename Params::TMA_B tma_load_b = make_tma_copy( + 
GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShapeB{}), shape<2>(TileShapeB{})), + _1{}); + typename Params::TMA_E tma_load_e = make_tma_copy( + GmemTiledCopyE{}, + tensor_e, + SmemLayoutE{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); + return { + tma_load_a, + tma_load_b, + tma_load_e, + args.layout_a, + args.layout_e, + args.ptr_E + }; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::upcast<2>(make_layout(make_shape(M, K, L), StrideA{}))); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + if constexpr (UseSmemE) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_e.get_tma_descriptor()); + } + } + + /// Create fragment for metadata. 
The function is referred from thrfrg_A(...) + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_E(Tensor&& tensor, TiledMMA& mma) { + CUTE_STATIC_ASSERT_V(rank(tensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutE_TV = typename Atom::Traits::ELayout; + + auto t_tile = make_tile(get<0>(TiledPerm{}), + get<2>(TiledPerm{})); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + auto t_tensor = logical_divide(tensor, t_tile); + + // Tile the tensor for the Atom + auto e_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto e_tensor = zipped_divide(t_tensor, e_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = e_tensor.compose(AtomLayoutE_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + // Fragment layout + return thr_tensor; + } + + /// get metadata TV + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutE_TV(TiledMma& mma) + { + // (M,K) -> (M,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_E = make_layout(make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto etile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_E(ref_E, mma).compose(etile, _).compose(thridx_2_thrid, _); + } + + /// Partitioning for metadata. 
+ template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_E(Tensor&& tensor, ThrMma& thread_mma) { + auto thr_tensor = make_tensor(static_cast(tensor).data(), thrfrg_E(tensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + + auto thr_vmk = make_coord(get<0>(thr_vmnk), make_coord(get<1>(thr_vmnk), get<3>(thr_vmnk))); + auto partition = thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition.layout()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(mainloop_params.layout_a.shape()); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + Tensor mE_mkl = mainloop_params.tma_load_e.get_tma_tensor(mainloop_params.layout_e.shape()); // (m,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShapeB{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor gE_mkl = local_tile(mE_mkl, TileShape{}, make_coord(_,_,_), 
Step<_1, X,_1>{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gE_mkl); + } + + /// Issues loads for A/E only (used when DMA warp is split). + template < + class TensorA, class TensorB, class TensorE, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_MK( + Params const& mainloop_params, + MainloopPipelineMK pipeline, + PipelineStateMK smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (BLK_M,BLK_K,PIPE) + + // Prepare the TMA loads for A and B + Tensor gA_mkl = get<0>(load_inputs); + Tensor gE_mkl = get<2>(load_inputs); + auto block_tma_a = mainloop_params.tma_load_a.get_slice(0); + auto block_tma_e = mainloop_params.tma_load_e.get_slice(0); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K, k) + Tensor gE = gE_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K, k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K, k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tEgE = block_tma_e.partition_S(gE); // (TMA,TMA_M,TMA_K, k) + Tensor tEsE = block_tma_e.partition_D(sE); // (TMA,TMA_M,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipelineMK::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_a.with(*tma_barrier), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + if constexpr (UseSmemE) { + copy(mainloop_params.tma_load_e.with(*tma_barrier), tEgE(_,_,_,*k_tile_iter), tEsE(_,_,_,write_stage)); + } + } + + if constexpr (!UseSmemE) { + auto blk_coord_mkl = make_coord(get<0>(blk_coord), *k_tile_iter, get<3>(blk_coord)); // (BLK_M,BLK_K,L) + prefetch(make_local_E(mainloop_params, blk_coord_mkl)); + } + + // Advance smem_pipe_write + ++k_tile_iter; + ++smem_pipe_write; + } + } + + /// Issues loads for B only (used when DMA warp is split). 
+ template < + class TensorA, class TensorB, class TensorE, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_NK( + Params const& mainloop_params, + MainloopPipelineNK pipeline, + PipelineStateNK smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Prepare the TMA loads for A and B + Tensor gB_nkl = get<1>(load_inputs); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(0); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K, k) + + // Applies the mapping from block_tma_a + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K, k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipelineNK::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_b.with(*tma_barrier), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } + + // Advance smem_pipe_write + ++k_tile_iter; + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + template + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + 
/* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + // Local tile E from global memory. + template + CUTLASS_DEVICE auto + make_local_E(Params const& mainloop_params, + BlockCoord const& blk_coord) { + // E layout + auto layoutE = mainloop_params.layout_e; + // E data pointer as sparse datatype + auto ptr_E = recast_ptr(mainloop_params.ptr_E); + + // Global gmem E + Tensor gE = make_tensor(make_gmem_ptr(ptr_E), layoutE); // (BLK_M,BLK_K,BLK_L) + // Local tile E + return local_tile(gE, select<0,2>(TileShape{}), blk_coord); // (BLK_M,BLK_K) + } + + // Load E from global memory to registers. + template + CUTLASS_DEVICE auto + load_E(Params const& mainloop_params, + BlockCoord const& blk_coord, + ProblemShape_MNKL const& problem_shape_MNKL, + int thread_idx) { + // Workload + auto [M, N, K, L] = problem_shape_MNKL; + auto [m_coord, k_coord, l_coord] = blk_coord; + auto Shape_MK = cute::make_tuple(M, K); + + // Tiled mma and thread mma + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + // Tile shape + auto tile_shape_mnk = tile_shape(tiled_mma); + // Re-sue copy atom E from SmemCopyAtomE + using GmemCopyAtomeE = SmemCopyAtomE; + // Gmem tile copy + auto gmem_tiled_copy_E = make_tiled_copy_impl(GmemCopyAtomeE{}, + get_layoutE_TV(tiled_mma), + make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + // Gmem thread copy + auto gmem_thr_copy_E = gmem_tiled_copy_E.get_thread_slice(thread_idx); + // Gmem local E + auto gE_mkl = make_local_E(mainloop_params, blk_coord); + // Tiled gmem E + Tensor tCgE = gmem_thr_copy_E.partition_S(gE_mkl); // (CPY,CPY_M,CPY_K) + // Tiled register E and copy view + Tensor tCrE = partition_fragment_E(gE_mkl, thread_mma); // (MMA,MMA_M,MMA_K) 
+ Tensor tCrE_copy_view = gmem_thr_copy_E.retile_D(tCrE); // (CPY,CPY_M,CPY_K) + + if constexpr (IsF8F6F4) { + auto get_copy_atom_and_common_vec = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + using ValType = typename decltype(tCrE)::value_type; + // Get maximum copy vector size (logically) + auto common_layout = max_common_layout(tCgE, tCrE); + auto vec_elem = cute::min(size(common_layout), Int<128 / sizeof_bits_v>{}); + auto common_vec = composition(common_layout, vec_elem); + // Compose a Copy_Atom + using VecType = uint_bit_t>; + using cpy = Copy_Atom, ValType>; + return cute::make_tuple(cpy{}, common_vec); + }; + + // Copy depends on whether predication is needed + if constexpr (IsELoadPred) { + // Get predication based on logical element coordinates. + Tensor cE_mk = local_tile( + make_identity_tensor(Shape_MK), + make_shape(get<0>(TileShape{}), get<2>(TileShape{})), + make_shape(m_coord, k_coord)); // (BLK_M, BLK_K) + Tensor tCcE = gmem_thr_copy_E.partition_S(cE_mk); // (CPY,CPY_M,CPY_K) + auto [atom, vec] = get_copy_atom_and_common_vec(); + // Coordinate comparison for out of bound (OOB) predication + Tensor tZpE = cute::lazy::transform(zipped_divide(tCcE, vec), [&](auto const& c){ return cute::elem_less(c, Shape_MK); }); + // Copy + cute::copy_if(atom, tZpE, zipped_divide(tCgE, vec), zipped_divide(tCrE_copy_view, vec)); + } + else { + // Copy + cute::copy(cute::AutoVectorizingCopyWithAssumedAlignment<32>{}, tCgE, tCrE_copy_view); + } + } + return tCrE; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC, + class KTileIterator, + class CtaTileCoord, + class ProblemShape_MNKL + > + CUTLASS_DEVICE void + mma(MainloopPipelineMK pipeline_mk, + PipelineStateMK smem_pipe_read_mk, + MainloopPipelineNK pipeline_nk, + PipelineStateNK smem_pipe_read_nk, + FrgTensorC& accum, + KTileIterator k_tile_iter, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& 
mainloop_params, + CtaTileCoord const& cta_tile_coord, + ProblemShape_MNKL const& problem_shape_MNKL) { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + clear(accum); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (BLK_M,BLK_K,PIPE) + + // + // Define A/B/E partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thread_mma.partition_fragment_B(sB(_,_,Int<0>{})); // (MMA,MMA_N,MMA_K) + Tensor tCrE = partition_fragment_E(sE(_,_,Int<0>{}), thread_mma); // (MMA,MMA_M,MMA_K) + + // + // Copy Atom A, B and E retiling + // + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S( + as_position_independent_swizzle_tensor(sA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S( + as_position_independent_swizzle_tensor(sB)); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) + + auto tile_shape_mnk = tile_shape(tiled_mma); + auto smem_tiled_copy_E = make_tiled_copy_impl(SmemCopyAtomE{}, + get_layoutE_TV(tiled_mma), + make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(thread_idx); 
+ Tensor tCsE = smem_thr_copy_E.partition_S( + as_position_independent_swizzle_tensor(sE)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrE_copy_view = smem_thr_copy_E.retile_D(tCrE); // (CPY,CPY_M,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<1>(tCsE) == size<1>(tCrE_copy_view)); + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB) * Int{}); + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == Int{}); + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == Int{}); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); + if constexpr (UseSmemE) { + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sE)); + } + + // + // DEFINE FUNCTIONS FOR PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineStateMK smem_pipe_release_mk = smem_pipe_read_mk; + PipelineStateNK smem_pipe_release_nk = smem_pipe_read_nk; + + // Wait consumer barrier MK + auto wait_barrier_mk = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto barrier_token_mk = pipeline_mk.consumer_try_wait(smem_pipe_read_mk); + pipeline_mk.consumer_wait(smem_pipe_read_mk, barrier_token_mk); + }; + + // Wait consumer barrier NK + auto wait_barrier_nk = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto barrier_token_nk = pipeline_nk.consumer_try_wait(smem_pipe_read_nk); + pipeline_nk.consumer_wait(smem_pipe_read_nk, barrier_token_nk); + }; + + // Release consumer barrier MK, and move forward + auto release_advance_mk = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + pipeline_mk.consumer_release(smem_pipe_release_mk); + ++smem_pipe_read_mk; + ++smem_pipe_release_mk; + }; + + // Release consumer barrier NK, and move forward + auto release_advance_nk = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + pipeline_nk.consumer_release(smem_pipe_release_nk); + ++smem_pipe_read_nk; + 
++smem_pipe_release_nk; + }; + + // Copy A from SMEM to register, and do transform if needed + auto copy_transform_A = [&](auto m_block, auto k_block) CUTLASS_LAMBDA_FUNC_INLINE { + // copy smem->rmem for A operand + copy(smem_tiled_copy_A, tCsA(_,m_block,k_block,smem_pipe_read_mk.index()), tCrA_copy_view(_,m_block,k_block)); + // Perform transform if needed. + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_A(MMAOp{}, tCrA(_,m_block,k_block)); + }; + + // Copy B from SMEM to register, and do transform if needed + auto copy_transform_B = [&](auto n_block, auto k_block) CUTLASS_LAMBDA_FUNC_INLINE { + // copy smem->rmem for B operand + copy(smem_tiled_copy_B, tCsB(_,n_block,k_block,smem_pipe_read_nk.index()), tCrB_copy_view(_,n_block,k_block)); + // Perform transform if needed. + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_B(MMAOp{}, tCrB(_,n_block,k_block)); + }; + + // Copy E from SMEM to register + auto copy_E = [&](auto m_block, auto k_block) CUTLASS_LAMBDA_FUNC_INLINE { + // copy smem->rmem for E operand + copy( recast(tCsE(_,m_block,k_block,smem_pipe_read_mk.index())), + recast(tCrE_copy_view(_,m_block,k_block))); + }; + + // TILE M/N/K for one TILE block + constexpr auto M_BLOCK_MAX = size<1>(tCrA); + constexpr auto N_BLOCK_MAX = size<1>(tCrB); + constexpr auto K_BLOCK_MAX = size<2>(tCrA); + constexpr auto K_BLOCK_STEP = K_BLOCK_MAX / Int{}; + + // Perform mainloop gemm, when E is in SMEM. + auto gemm_loop_with_SmemE = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + // WAIT on smem_pipe_read until data is available + wait_barrier_mk(); + wait_barrier_nk(); + + // Load A/B/E, then do gemm. 
+ for_each(make_int_sequence{}, [&] (auto k_block) { + for_each(make_int_sequence{}, [&] (auto n_block) { + // Copy smem->rmem for B operand + copy_transform_B(n_block, k_block); + + for_each(make_int_sequence{}, [&] (auto m_block) { + // Copy smem->rmem for A operand + copy_transform_A(m_block, k_block); + copy_E(m_block, k_block); + + // Gemm + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,m_block,k_block), tCrE(_,m_block,k_block)), + tCrB(_,n_block,k_block), + accum(_,m_block,n_block)); + }); + }); + }); + + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + + // Advance consumer pipeline mk/nk + release_advance_mk(); + release_advance_nk(); + }; + + // Perform mainloop gemm, when E is in GMEM. + auto gemm_loop_with_GmemE = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + // Copy gmem->rmem for E operand + auto blk_coord = make_coord(get<0>(cta_tile_coord), *k_tile_iter, get<3>(cta_tile_coord)); // (BLK_M,BLK_K,L) + Tensor tCrE = load_E(mainloop_params, blk_coord, problem_shape_MNKL, thread_idx); + ++k_tile_iter; + + // WAIT on smem_pipe_read until data is available + wait_barrier_mk(); + wait_barrier_nk(); + + for_each(make_int_sequence{}, [&] (auto k_block) { + for_each(make_int_sequence{}, [&] (auto n_block) { + // Copy smem->rmem for B operand + copy_transform_B(n_block, k_block); + + for_each(make_int_sequence{}, [&] (auto m_block) { + // Copy smem->rmem for A operand + copy_transform_A(m_block, k_block); + + // Gemm + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,m_block,k_block), tCrE(_,m_block,k_block)), + tCrB(_,n_block,k_block), + accum(_,m_block,n_block)); + }); + }); + }); + + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + + // Advance consumer pipeline_nk + release_advance_nk(); + // Wait next buffer + wait_barrier_nk(); + + for_each(make_int_sequence{}, [&] (auto k_block) { + auto k_block_a = k_block + K_BLOCK_STEP; 
+ for_each(make_int_sequence{}, [&] (auto n_block) { + // Copy smem->rmem for B operand + copy_transform_B(n_block, k_block); + + for_each(make_int_sequence{}, [&] (auto m_block) { + // Copy smem->rmem for A operand + copy_transform_A(m_block, k_block_a); + + // Gemm + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,m_block,k_block_a), tCrE(_,m_block,k_block_a)), + tCrB(_,n_block,k_block), + accum(_,m_block,n_block)); + }); + }); + }); + + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + + // Advance consumer pipeline mk/nk + release_advance_mk(); + release_advance_nk(); + }; + + + // + // PIPELINED MAIN LOOP + // + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + // Case when A/B with same stages, and keep E in SMEM. + if constexpr (UseSmemE) { + gemm_loop_with_SmemE(); + } + // Case when A/B with different stages, and keep E in GMEM. + else { + gemm_loop_with_GmemE(); + } // end if + + } // end loop k_tile_count + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipelineMK, PipelineStateMK, MainloopPipelineNK, PipelineStateNK, int) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a1b6f8589a249ce7fe9112d8be3f6a4f83eebc4a --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm70_mma_twostage.hpp @@ -0,0 +1,600 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/gemm/collective/collective_mma_decl.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm70TwoStageUnpredicated, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm70TwoStageUnpredicated; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + 
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})))); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // Device side kernel params + using Params = Arguments; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { + (void) workspace; + return args; + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + (void)residue_mnk; + + static_assert(is_rmem::value, "D tensor must be rmem 
resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 2, + "MainloopTwoStage must not have a smem shape with a pipeline mode."); + static_assert(cute::rank(SmemLayoutB{}) == 2, + "MainloopTwoStage must not have a smem shape with a pipeline mode."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_a; + GmemTiledCopyB gmem_tiled_copy_b; + auto copy_a_thr = gmem_tiled_copy_a.get_slice(thread_idx); + auto copy_b_thr = gmem_tiled_copy_b.get_slice(thread_idx); + + Tensor tAgA = copy_a_thr.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = copy_a_thr.partition_D(sA); // (ACPY,ACPY_M,ACPY_K) + Tensor tBgB = copy_b_thr.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = copy_b_thr.partition_D(sB); // (BCPY,BCPY_N,BCPY_K) + + // Allocate the register tiles for double buffering -- same shape as partitioned data + Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) + Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_N,BCPY_K) + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_M,MMA_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // 
MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto thr_copy_A = smem_tiled_copy_a.get_thread_slice(thread_idx); + Tensor tCsA = thr_copy_A.partition_S(sA); + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + + auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto thr_copy_B = smem_tiled_copy_b.get_thread_slice(thread_idx); + Tensor tCsB = thr_copy_B.partition_S(sB); + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + + // + // Prologue + // + + // Copy gmem to rmem for the first k_tile + copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tArA); + copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBrB); + if (--k_tile_count > 0) ++k_tile_iter; + // Copy rmem to smem + copy(tArA, tAsA); + copy(tBrB, tBsB); + // Clear accumulators + __syncthreads(); + + // Load A, B smem->rmem for k=0 + copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0)); + copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0)); + // + // Mainloop + // + + // Size of the k-tiles's outer product mode (k) + auto K_BLOCK_MAX = size<2>(tCrA); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > -1) + { + // Pipeline the outer products with a static for loop + for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + __syncthreads(); + + // Copy rmem to smem + copy(tArA, tAsA); + copy(tBrB, tBsB); + __syncthreads(); + } + + // Load A, B smem->rmem for k+1 + int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), 
tCrB_copy_view(_,_,k_block_next)); + if (k_block == 0) + { + // Copy gmem to rmem + copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tArA); + copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBrB); + if (--k_tile_count > 0) ++k_tile_iter; + } + + // transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + + // Thread-level register gemm for k + // disambiguate gemm (shared with the namespace name) + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm70TwoStage, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm70TwoStage; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(cute::rank(SmemLayoutAtomA{}) 
== 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})))); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // Device side kernel params + using Params = Arguments; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { + (void) workspace; + return args; + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem 
resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 2, + "MainloopTwoStage must not have a smem shape with a pipeline mode."); + static_assert(cute::rank(SmemLayoutB{}) == 2, + "MainloopTwoStage must not have a smem shape with a pipeline mode."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + gA.data() = &gA(0, get<2>(residue_mnk), 0); + gB.data() = &gB(0, get<2>(residue_mnk), 0); + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_a; + GmemTiledCopyB gmem_tiled_copy_b; + auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); + auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // Allocate the register tiles for double buffering -- same shape as partitioned data + Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) + Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_N,BCPY_K) + + // + // PREDICATES + // + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct 
identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // + // PREFETCH + // + + // Clear the rmem tiles to account for predicated off loads + clear(tArA); + clear(tBrB); + + // Start async loads for 0th k-tile, where we take care of the k residue + { + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tArA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) + copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tArA(_,_,k)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBrB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBrB(_,_,k)); + } + } + ++k_tile_iter; + --k_tile_count; + } + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB)); // (MMA,MMA_M,MMA_K) + + 
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto thr_copy_A = smem_tiled_copy_a.get_thread_slice(thread_idx); + Tensor tCsA = thr_copy_A.partition_S(sA); + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + + auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto thr_copy_B = smem_tiled_copy_b.get_thread_slice(thread_idx); + Tensor tCsB = thr_copy_B.partition_S(sB); + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + + // + // Prologue + // + + // Copy rmem to smem + copy(tArA, tAsA); + copy(tBrB, tBsB); + // Clear accumulators + __syncthreads(); + + // Load A, B smem->rmem for k=0 + copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0)); + copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0)); + // + // Mainloop + // + + // Size of the k-tiles's outer product mode (k) + auto K_BLOCK_MAX = size<2>(tCrA); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > -1) + { + // Pipeline the outer products with a static for loop + for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + __syncthreads(); + + // Copy rmem to smem + copy(tArA, tAsA); + copy(tBrB, tBsB); + __syncthreads(); + } + + // Load A, B smem->rmem for k+1 + int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + if 
(k_block == 0) + { + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tArA); + copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBrB); + ++k_tile_iter; + --k_tile_count; + } + + // transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + + // Thread-level register gemm for k + // disambiguate gemm (shared with the namespace name) + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b83e04891244a840339af1f639fa3bbe74c58d66 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm80_mma_array_multistage.hpp @@ -0,0 +1,412 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class ClusterShape_, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_ +> +struct CollectiveMma< + MainloopSm80ArrayCpAsync< + Stages, + ClusterShape_>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_ + > +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm80ArrayCpAsync< + Stages, + ClusterShape_>; + using TileShape = TileShape_; + // Follow the change in TestSmall: TileShape switch to CtaShape + // In legacy arch, it should be same + using CtaShape_MNK = TileShape; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + 
using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A{nullptr}; + StrideA dA{}; + ElementB const** ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + using Params = Arguments; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { + (void) workspace; + return args; + } + + /// Perform a collective-scoped matrix multiply-accumulate + 
template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, // (BLK_M, BLK_K, K_TILES) + TensorB gB, // (BLK_N, BLK_K, K_TILES) + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M + CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K + CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + gA = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA); + gB = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB); + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_A; + GmemTiledCopyB 
gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // + // PREDICATES + // + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_A.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_B.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // + // PREFETCH + // + + // Clear the smem tiles to account for predicated off loads + clear(tAsA); + clear(tBsB); + + // Start async loads for 0th k-tile, where we take care of the k residue + { + constexpr int k_pipe = 0; + + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < 
residue_k (gA shifted) + copy_if(gmem_tiled_copy_A, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_B, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); + } + } + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // Start async loads for 1st k-tile onwards, no k-residue handling needed + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync + copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // + // MMA Atom partitioning + // + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == 
size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + + // + // PIPELINED MAIN LOOP + // + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = DispatchPolicy::Stages-1; + + Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); + Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + // PREFETCH register pipeline + if (K_BLOCK_MAX > 1) { + // Wait until our first prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); + } + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) + { + // Pipeline the outer products with a static for loop. + // + // Note, the for_each() function is required here to ensure `k_block` is of type Int. 
+ for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + // Slice the smem_pipe_read smem + tCsA_p = tCsA(_,_,_,smem_pipe_read); + tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Commit the smem for smem_pipe_read + cp_async_wait(); + __syncthreads(); + } + + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // Set all predicates to false if we are going to overshoot bounds + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); + copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); + cp_async_fence(); + ++k_tile_iter; + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 
0 : smem_pipe_read; + } + + // Transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + + } + + cp_async_wait<0>(); + __syncthreads(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2e3e394dc10a5d18eebf7e185894c1a9de303e8a --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm80_mma_multistage.hpp @@ -0,0 +1,706 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + 
MainloopSm80CpAsyncUnpredicated, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_ + > +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm80CpAsyncUnpredicated; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + // Follow the change in TestSmall: TileShape switch to CtaShape + // For sm80 arch, CtaShape should equal to TileShape + using CtaShape_MNK = TileShape; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + 
SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // Device side kernel params + using Params = Arguments; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { + (void) workspace; + return args; + } + + /// Perform a collective-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, + "MainloopSm80CpAsync must have a pipeline mode in the smem layout."); + static_assert(cute::rank(SmemLayoutB{}) == 3, + "MainloopSm80CpAsync must have a pipeline mode in the smem layout."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + 
CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M + CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K + CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_A; + GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // + // PREDICATES + // + + (void) residue_mnk; + //assert(residue_mnk == make_tuple(0,0,0)); + + // + // PREFETCH + // + + // Start async loads for all pipes but the last + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { + copy(gmem_tiled_copy_A, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); + copy(gmem_tiled_copy_B, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); + cp_async_fence(); + --k_tile_count; + if (k_tile_count > 0) { ++k_tile_iter; } + } + + // + // MMA Atom partitioning + // + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + 
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + CUTE_STATIC_ASSERT_V(size(gmem_tiled_copy_A) == size(tiled_mma)); + CUTE_STATIC_ASSERT_V(size(gmem_tiled_copy_B) == size(tiled_mma)); + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + + // + // PIPELINED MAIN LOOP + // + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = DispatchPolicy::Stages-1; + + Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); + Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + // PREFETCH register pipeline + if (K_BLOCK_MAX > 1) { + // Wait until our first prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), 
tCrB_copy_view(_,_,Int<0>{})); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > -(DispatchPolicy::Stages-1)) + { + // Pipeline the outer products with a static for loop. + // + // Note, the for_each() function is required here to ensure `k_block` is of type Int. + for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + // Slice the smem_pipe_read smem + tCsA_p = tCsA(_,_,_,smem_pipe_read); + tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Commit the smem for smem_pipe_read + cp_async_wait(); + __syncthreads(); + } + + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + copy(gmem_tiled_copy_A, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); + copy(gmem_tiled_copy_B, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); + cp_async_fence(); + + // Advance the tile + --k_tile_count; + if (k_tile_count > 0) { ++k_tile_iter; } + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 
0 : smem_pipe_read; + } + + // Transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + + } + + cp_async_wait<0>(); + __syncthreads(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class ClusterShape_, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_ +> +struct CollectiveMma< + MainloopSm80CpAsync< + Stages, + ClusterShape_>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_ + > +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm80CpAsync< + Stages, + ClusterShape_>; + using TileShape = TileShape_; + // Follow the change in TestSmall: TileShape switch to CtaShape + // In legacy arch, it should be same + using CtaShape_MNK = TileShape; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + 
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "CpAsync mainloop must have at least 2 stages in the pipeline."); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // Device side kernel params + using Params = Arguments; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { + (void) workspace; + return args; + } + + /// Perform a collective-scoped matrix multiply-accumulate + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, // (BLK_M, BLK_K, K_TILES) + TensorB gB, // (BLK_N, BLK_K, K_TILES) + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int 
k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + using namespace cute; + + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + // Construct shared memory tiles + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<0>(gA) == size<0>(sA)); // BLK_M + CUTE_STATIC_ASSERT_V(size<1>(gA) == size<1>(sA)); // BLK_K + CUTE_STATIC_ASSERT_V(size<0>(gB) == size<0>(sB)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(gB) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // BLK_K + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + gA = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA); + gB = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB); + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_A; + GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = 
gmem_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // + // PREDICATES + // + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_A.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_B.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // + // PREFETCH + // + + // Clear the smem tiles to account for predicated off loads + clear(tAsA); + clear(tBsB); + + // Start async loads for 0th k-tile, where we take care of the k residue + { + constexpr int k_pipe = 0; + + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) + copy_if(gmem_tiled_copy_A, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + 
copy_if(gmem_tiled_copy_B, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); + } + } + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // Start async loads for 1st k-tile onwards, no k-residue handling needed + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync + copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync + cp_async_fence(); + ++k_tile_iter; + --k_tile_count; + } + + // + // MMA Atom partitioning + // + + // Tile MMA compute thread partitions and allocate accumulators + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCrA = thr_mma.partition_fragment_A(sA(_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB(_,_,0)); // (MMA,MMA_N,MMA_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(src_accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(src_accum)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); // 
(CPY,CPY_N,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + + // + // PIPELINED MAIN LOOP + // + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = DispatchPolicy::Stages-1; + + Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); + Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + // PREFETCH register pipeline + if (K_BLOCK_MAX > 1) { + // Wait until our first prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); + } + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) + { + // Pipeline the outer products with a static for loop. + // + // Note, the for_each() function is required here to ensure `k_block` is of type Int. 
+ for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + // Slice the smem_pipe_read smem + tCsA_p = tCsA(_,_,_,smem_pipe_read); + tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Commit the smem for smem_pipe_read + cp_async_wait(); + __syncthreads(); + } + + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(smem_tiled_copy_A, tCsA_p(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); + copy(smem_tiled_copy_B, tCsB_p(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // Set all predicates to false if we are going to overshoot bounds + if (k_tile_count <= 0) { + clear(tApA); + clear(tBpB); + } + copy_if(gmem_tiled_copy_A, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); + copy_if(gmem_tiled_copy_B, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); + cp_async_fence(); + ++k_tile_iter; + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? 
0 : smem_pipe_read; + } + + // Transform before compute + cute::transform(tCrA(_,_,k_block), TransformA{}); + cute::transform(tCrB(_,_,k_block), TransformB{}); + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tCrA(_,_,k_block), tCrB(_,_,k_block), src_accum); + }); + + } + + cp_async_wait<0>(); + __syncthreads(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fa5e212d61b06ec8ebe9f8ea39eb505c418f896f --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -0,0 +1,1380 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/detail/collective/mixed_input_utils.hpp" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class 
KernelSchedule_, + class TileShape_, + class ElementAOptionalTuple, + class StrideA_, + class ElementBOptionalTuple, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput, + TileShape_, + ElementAOptionalTuple, + StrideA_, + ElementBOptionalTuple, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ +public: + // + // Type Aliases + // + using ConversionMode = cutlass::detail::ConversionMode; + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput; + using TileShape = TileShape_; + using KernelSchedule = KernelSchedule_; + +private: + template friend struct detail::MixedInputUtils; + using CollectiveType = CollectiveMma; + using Utils = detail::MixedInputUtils; + + // + // Type Aliases + // + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + +public: + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale], [ElementZero]}. 
Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. + using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + using StrideScale = cute::Stride, int64_t, int64_t>; + using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert(( IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)) || + (!IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)), + "The transformed type must be K-major."); + + static_assert(( IsATransformed && (sizeof(ElementB) == 2)) || + (!IsATransformed && (sizeof(ElementA) == 2)) || + ((cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value) && + (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + using 
SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using SmemCopyAtomScale = Copy_Atom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using SwappedStrideA = cute::conditional_t; + using SwappedStrideB = cute::conditional_t; + using InternalSwappedStrideA = cute::conditional_t; + using InternalSwappedStrideB = cute::conditional_t; + using SwappedSmemLayoutAtomA = cute::conditional_t; + using SwappedSmemLayoutAtomB = cute::conditional_t; + using SwappedSmemCopyAtomA = cute::conditional_t; + using SwappedSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealSwappedElementA = cute::conditional_t; + using RealSwappedElementB = cute::conditional_t; + using SwappedElementA = cute::conditional_t; + using SwappedElementB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using SwappedTransformA = cute::conditional_t; + using SwappedTransformB = cute::conditional_t; + using ArchTag = typename DispatchPolicy::ArchTag; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + using TmaElementScale = uint_bit_t >; // in case we have array. 
translating to uint to satisfy tma descriptor's specialization + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static constexpr int NumProducerThreadEvents = 1; + + using SmemLayoutAtomScale = Layout(SwappedSmemLayoutAtomA{})), cute::Int<1>>>; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); + + static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); + + /// Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomA{}, select<0,2>(TileShape{}), InternalSwappedStrideA{})); + using SmemLayoutB = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomB{}, select<1,2>(TileShape{}), InternalSwappedStrideB{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,NonVoidStrideScale>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. + // We must also handle updating the pipeline transaction bytes on the fly. 
+ static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); + +private: + static constexpr ConversionMode + get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } + else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } + else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + +public: + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && + cutlass::detail::is_Array_v; + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage { + static constexpr int scale_elements = Utils::elements_per_smem_scale(); + static constexpr int zero_elements = Utils::elements_per_smem_zero(); + struct TensorStorage { + CUTE_ALIGNAS(SmemAlignmentA) cute::ArrayEngine> smem_A; + CUTE_ALIGNAS(SmemAlignmentB) cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + struct TensorMapStorage { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_scale; + cute::TmaDescriptor smem_tensormap_zero; + }; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename 
SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + ElementScale const** ptr_S = nullptr; + NonVoidStrideScale const* dS{}; + int chunk_size = 0; + ElementZero const** ptr_Z = nullptr; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using LayoutA = decltype(detail::get_gmem_layout(repeat_like(InternalSwappedStrideA{}, int32_t(0)), InternalSwappedStrideA{})); + using LayoutB = decltype(detail::get_gmem_layout(repeat_like(InternalSwappedStrideB{}, int32_t(0)), InternalSwappedStrideB{})); + + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), LayoutA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), LayoutB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + using TMA_Scale = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. 
Scale is ALWAYS loaded with A for RF kernel + + using TMA_Zero = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + void* tensormaps; + SwappedElementA const** ptr_A; + SwappedStrideA ptr_dA; + SwappedElementB const** ptr_B; + SwappedStrideB ptr_dB; + NonVoidElementScale const** ptr_S; + NonVoidStrideScale const* dS; + NonVoidElementZero const** ptr_Z; + int64_t scale_k; + int chunk_size; + int reload_factor = (chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + InternalSwappedStrideA dA; + InternalSwappedStrideB dB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace) { + + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. 
+ auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + + if constexpr (SwapAB) { + init_M = get<1>(init_shape); + init_N = get<0>(init_shape); + } + // Batches/Groups are managed by using appropriate pointers to input matrices + const uint32_t mock_L = 1; + SwappedElementA const* ptr_A_first_batch; + SwappedElementB const* ptr_B_first_batch; + SwappedStrideA ptr_dA; + SwappedStrideB ptr_dB; + InternalSwappedStrideA dA; + InternalSwappedStrideB dB; + + if constexpr (not SwapAB) { + ptr_A_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_A) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + ptr_B_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_B) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + } + else { + ptr_A_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_B) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + ptr_B_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_A) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + } + + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + if constexpr (not SwapAB) { + ptr_dA = args.dA; + ptr_dB = args.dB; + } + else { + ptr_dA = args.dB; + ptr_dB = args.dA; + } + dA = InternalSwappedStrideA{}; + if constexpr (is_layout::value) { + dA = make_layout( + transform_leaf(dA.shape(), [](auto x){ + if constexpr (not is_static_v) { + return static_cast(1); + } else { + return x; + } + }), + dA.stride()); + } + dB = InternalSwappedStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
+ auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + if constexpr (SwapAB) { + init_M = get<1>(problem_shape_MNK); + init_N = get<0>(problem_shape_MNK); + } + + if constexpr (not SwapAB) { + dA = args.dA; + dB = args.dB; + } + else { + dA = args.dB; + dB = args.dA; + } + ptr_dA = SwappedStrideA{}; + ptr_dB = SwappedStrideB{}; + } + Tensor tensor_a = make_tensor(ptr_A_first_batch, detail::get_gmem_layout(make_shape(init_M,init_K,mock_L), dA)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, detail::get_gmem_layout(make_shape(init_N,init_K,mock_L), dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + typename Params::TMA_Scale tma_load_scale{}; + typename Params::TMA_Zero tma_load_zero{}; + + void* tensormaps = workspace; + auto args_setup = [&](auto ptr_A, auto ptr_B, int64_t scale_k = 0, int chunk_size = 0, int reload_factor = 1) -> Params { + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + tma_load_scale, + tma_load_zero, + tensormaps, + reinterpret_cast(ptr_A), + ptr_dA, + reinterpret_cast(ptr_B), + ptr_dB, + reinterpret_cast(args.ptr_S), + args.dS, + reinterpret_cast(args.ptr_Z), + scale_k, + chunk_size, + reload_factor, + dA, + dB + }; + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return SwapAB ? 
args_setup(args.ptr_B, args.ptr_A) + : args_setup(args.ptr_A, args.ptr_B); + } + else if constexpr (ModeHasScales) { + auto scale_k = ceil_div(init_K, args.chunk_size); + ElementScale const* ptr_S = reinterpret_cast(args.ptr_S); + StrideScale dS{}; + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M,scale_k,mock_L), dS)); + tma_load_scale = make_tma_copy( + GmemTiledCopyScale{}, + tensor_scale, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { + return SwapAB ? args_setup(args.ptr_B, args.ptr_A, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})) + : args_setup(args.ptr_A, args.ptr_B, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})); + } + else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + ElementZero const* ptr_Z = reinterpret_cast(args.ptr_Z); + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_layout(make_shape(init_M,scale_k,mock_L), dS)); + tma_load_zero = make_tma_copy( + GmemTiledCopyScale{}, + tensor_zero, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + return SwapAB ? 
args_setup(args.ptr_B, args.ptr_A, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})) + : args_setup(args.ptr_A, args.ptr_B, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})); + + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + + // Calculating workspace size + auto calculate_workspace_size = [SizeOfCuTensorMap, sm_count](uint32_t num_input_tensors) { + return num_input_tensors * SizeOfCuTensorMap * sm_count; + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return calculate_workspace_size(2); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies, followed by scale tensormap copies + return calculate_workspace_size(3); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies, followed by scale and zeros tensormap copies + return calculate_workspace_size(4); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in get_workspace_size."); + } + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = 
nullptr) { + return cutlass::Status::kSuccess; + } + + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto get_stride = [](auto stride) { + if constexpr (cute::is_pointer_v>) { + return *stride; + } + else { + return stride; + } + }; + auto dA = get_stride(args.dA); + auto dB = get_stride(args.dB); + implementable = implementable && cutlass::detail::check_alignment(detail::get_gmem_layout(cute::make_shape(M,K,L), dA)); + implementable = implementable && cutlass::detail::check_alignment(detail::get_gmem_layout(cute::make_shape(N,K,L), dB)); + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + const int scale_mn = SwapAB ? 
N : M; + const int scale_k = ceil_div(K, args.chunk_size); + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.chunk_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = Utils::compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = Utils::compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytesExtra = Utils::compute_tma_transaction_bytes_extra(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesExtra; + + // Set up the data needed by this collective for 
load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the contract that the + // returned tuple must contain at least two elements, with the first two elements being: + // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + // The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(detail::get_gmem_layout(make_shape(M,K,mock_L), mainloop_params.dA))); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(shape(detail::get_gmem_layout(make_shape(N,K,mock_L), mainloop_params.dB))); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(gA_mkl, gB_nkl); + } + else if constexpr (ModeHasScales) { + const int scale_mn = SwapAB ? 
N : M; + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(scale_mn,scale_k,L)); + Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(scale_mn,scale_k,L)); + Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template < + class... Ts, + class... TMs, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + static_assert(sizeof... (TMs) == 2, "Direct convert needs two tensormaps"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + static_assert(sizeof... 
(TMs) == 3, "Scaled convert needs three tensormaps"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); + static_assert(sizeof... (TMs) == 4, "Scaled and zero convert needs four tensormaps"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + auto extra_input_partitions = Utils::partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), 
tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with chunk_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + const int scale_load_k = *k_tile_iter / mainloop_params.reload_factor; // This will always be 0 when chunk_size == K. + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_scale.with(get<2>(input_tensormaps), *tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + } + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_zero.with(get<3>(input_tensormaps), *tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + // This helps avoid early exit of 
blocks in Cluster. + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then it would just be acquired since the phase was + // still inverted from make_producer_start_state. + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SwappedSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SwappedSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + 
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsA = mma_thread_slice.partition_A(sA); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate fragments and descriptors + Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_load = [&]{ + if constexpr (not is_layout::value) { + // Make register tensor with MMA layout + return make_fragment_like(tCrA_mma); + } + else { + // Make register tensor matching smem layout, converter will take care of de-swizzling + return make_tensor_like(tCsA(_,_,_,Int<0>{})); + } + }(); + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + auto smem_tiled_copy_A = make_tiled_copy_A(SwappedSmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = Utils::partition_extra_mma_info(mma_thread_slice, shared_tensors); + auto copy_partitions_extra_info = Utils::retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + 
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + constexpr int K_WAIT_MAX = cute::min(K_BLOCK_MAX - 1, 7); + static_assert(K_BLOCK_MAX >= 4, "Consider increasing TileShapeK"); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); + if (K_BLOCK_MAX > 1) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, read_stage); + } + + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, 
read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first mma. + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); + + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + + warpgroup_wait(); + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, 
copy_partitions_extra_info, 0, smem_pipe_read.index()); + + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + } + else { + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + warpgroup_fence_operand(accum); + + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + 
smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + cute::TmaDescriptor* tma_desc_scale = &gmem_tensormap[sm_idx + 2*sm_count]; + cute::TmaDescriptor* tma_desc_zero = &gmem_tensormap[sm_idx + 3*sm_count]; + + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + if (cute::elect_one_sync()) { + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + Tensor pS_tensormap = make_tensor(mainloop_params.tma_load_scale.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sS_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_scale), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { + copy(recast(pS_tensormap), recast(sS_tensormap)); + } + } + else if constexpr (KernelConversionMode == 
ConversionMode::ConvertAndScaleWithZero) { + Tensor pZ_tensormap = make_tensor(mainloop_params.tma_load_zero.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sZ_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_zero), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { + copy(recast(pZ_tensormap), recast(sZ_tensormap)); + } + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_init."); + } + + __syncwarp(); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_scale, tma_desc_zero); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_init."); + } + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_scale, + mainloop_params.ptr_S[next_batch]); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + 
cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_zero, + mainloop_params.ptr_Z[next_batch]); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_replace_global_address."); + } + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = (SwapAB? get<1>(problem_shape_mnkl) : get<0>(problem_shape_mnkl)); + const uint32_t N = (SwapAB? get<0>(problem_shape_mnkl) : get<1>(problem_shape_mnkl)); + const uint32_t K = get<2>(problem_shape_mnkl); + + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + cute::array prob_shape_scale = {1,1,1,1,1}; + cute::array prob_stride_scale = {0,0,0,0,0}; + cute::array prob_shape_zero = {1,1,1,1,1}; + cute::array prob_stride_zero = {0,0,0,0,0}; + + SwappedElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, detail::get_gmem_layout(make_shape(M,K,Int<1>{}), mainloop_params.ptr_dA[next_group])); + + SwappedElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, detail::get_gmem_layout(make_shape(N,K,Int<1>{}), mainloop_params.ptr_dB[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + NonVoidElementScale const* ptr_S 
= nullptr; + auto scale_k = ceil_div(K, mainloop_params.chunk_size); + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_scale, tensor_scale, + prob_shape_scale, prob_stride_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + ElementZero const* ptr_Z = nullptr; + auto scale_k = ceil_div(K, mainloop_params.chunk_size); + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_zero, tensor_zero, + prob_shape_zero, prob_stride_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); + } + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_scale) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_zero) { + stride = (stride * sizeof_bits_v) / 8; + } + + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_scale, + prob_shape_scale, + prob_stride_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + 
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_zero, + prob_shape_zero, + prob_stride_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_cp_fence_release."); + } + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_fence_acquire(get<2>(input_tensormaps)); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::tma_descriptor_fence_acquire(get<3>(input_tensormaps)); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_fence_acquire."); + } + } + + template + CUTLASS_DEVICE + InputTensors + tensors_perform_update( + InputTensors const& input_tensors, + [[maybe_unused]] Params const& mainloop_params, + [[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl, + [[maybe_unused]] int32_t next_batch) { + return input_tensors; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} 
// namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6786cec5b6fc650fbb65e5fb810f32f786359bc6 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,775 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90ArrayTmaGmmaWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + 
TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + static constexpr int NumProducerThreadEvents = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + // Device side kernel params + struct Params { + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + void* tensormaps; + 
InternalElementA const** ptr_A; + StrideA dA; + InternalElementB const** ptr_B; + StrideB dB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + // Batches/Groups are managed by using appropriate pointers to input matrices + const uint32_t init_L = 1; + // NOTE: Since TMA desc creation with nullptr not possible until 12.6, we use an initial address even when tensor addresses are on device. This address is never used. + InternalElementA const* ptr_A_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_A) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + InternalElementB const* ptr_B_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_B) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + + InternalStrideA stride_a; + InternalStrideB stride_b; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
+ auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + void* tensormaps = workspace; + + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + tensormaps, + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / 
cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the contract that the + // returned tuple must contain at least two elements, with the first two elements being: + // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + // The rest of the tensors can be specified as needed by this collective. 
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t init_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,init_L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,init_L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template < + class TensorA, class TensorB, + class TensorMapA, class TensorMapB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, 
block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), 
tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + // This helps avoid early exit of blocks in Cluster. + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then it would just be acquired since the phase was + // still inverted from make_producer_start_state. + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread 
mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + warpgroup_fence_operand(accum); + if (k_tile_count > 0) { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = 
pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count - 1; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK 
smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // 
Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + InternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + InternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + 
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + + template + CUTLASS_DEVICE + InputTensors + tensors_perform_update( + InputTensors const& input_tensors, + [[maybe_unused]] Params const& mainloop_params, + [[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl, + [[maybe_unused]] int32_t next_batch) { + return input_tensors; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp new file mode 100644 index 0000000000000000000000000000000000000000..916c6db812ffb9279164e9d477e668b93ac60c2e --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp @@ -0,0 +1,784 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90ArrayTmaGmmaWarpSpecializedFP8, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using 
GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor 
smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t mma_promotion_interval = 4; + void* tensormaps; + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + auto init_L = get<3>(init_shape); + + // NOTE: Since TMA desc creation with nullptr not possible until 12.6, we use an initial address even when tensor addresses are on device. This address is never used. 
+ ElementA const* ptr_A_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_A) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + ElementB const* ptr_B_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_B) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + + InternalStrideA stride_a; + InternalStrideB stride_b; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + void* tensormaps = workspace; + + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + args.mma_promotion_interval, + tensormaps, + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = 
sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. 
The collective and the kernel layer have the contract that the + // returned tuple must contain at least two elements, with the first two elements being: + // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + // The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,mock_L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,mock_L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template < + class TensorA, class TensorB, + class TensorMapA, class TensorMapB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = 
make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for 
_writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 
GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) 
with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + accumulation.promote_if_needed(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int 
k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation()); + + accumulation.promote_if_needed(); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation.promote_residue_if_needed(); + + warpgroup_fence_operand(accumulation()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = 
make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + ElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + ElementB const* 
ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + + template + CUTLASS_DEVICE + InputTensors + tensors_perform_update( + InputTensors const& input_tensors, + [[maybe_unused]] Params const& mainloop_params, + [[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl, + [[maybe_unused]] int32_t next_batch) { + return input_tensors; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b6e662beb26d411ed6af326b6d5c2420b5b3bb3a --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -0,0 +1,1245 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm80.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/detail/blockwise_scale_layout.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StridePairA_, + class ElementB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise, + TileShape_, + ElementA_, + StridePairA_, + ElementB_, + StridePairB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = cute::tuple_element_t<0,StridePairA_>; + using LayoutSFA = cute::tuple_element_t<1,StridePairA_>; + using InternalStrideA = cute::remove_pointer_t; + using 
InternalLayoutSFA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = cute::tuple_element_t<0,StridePairB_>; + using LayoutSFB = cute::tuple_element_t<1,StridePairB_>; + using InternalStrideB = cute::remove_pointer_t; + using InternalLayoutSFB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementBlockScale = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + static constexpr int NumProducerThreadEvents = 33; + + static constexpr int ScaleGranularityM = size<0,0>(InternalLayoutSFA{}); + static constexpr int ScaleGranularityN = size<0,0>(InternalLayoutSFB{}); + static constexpr int ScaleGranularityK = size<1,0>(InternalLayoutSFA{}); + + static_assert(size<2>(TileShape{}) % ScaleGranularityK == 0); + static_assert(ScaleGranularityK % size<2>(typename TiledMma::AtomShape_MNK{}) == 0); + + static constexpr int ScalePromotionInterval = ScaleGranularityK / size<2>(typename TiledMma::AtomShape_MNK{}); + static_assert(ScalePromotionInterval % 4 == 0, "ScalePromotionInterval must be a multiple of 4."); + static_assert(ScalePromotionInterval >= size<2>(TileShape{}) / tile_size<2>(TiledMma{}), + "ScalePromotionInterval must be greater than or equal to the number of stages of the MMA atom."); + static_assert(ScalePromotionInterval % (size<2>(TileShape{}) / tile_size<2>(TiledMma{})) == 0, + 
"ScalePromotionInterval must be a multiple of the number of stages of the MMA atom."); + + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + static_assert((size<1>(TileShape{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N."); + + static constexpr bool MMajorSFA = size<0,1>(InternalLayoutSFA{}.stride()) == 1; + static constexpr bool NMajorSFB = size<0,1>(InternalLayoutSFB{}.stride()) == 1; + + using ScaleConfig = ::cutlass::detail::Sm90BlockwiseScaleConfig< + ScaleGranularityM, + ScaleGranularityN, + ScaleGranularityK, + MMajorSFA ? cute::GMMA::Major::MN : cute::GMMA::Major::K, + NMajorSFB ? cute::GMMA::Major::MN : cute::GMMA::Major::K>; + using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(TileShape{})); + using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(TileShape{})); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + // we can have partial tiles in M or N, so don't vectorize those loads + using CopyAtomSFA = Copy_Atom, ElementBlockScale>; + using CopyAtomSFB = Copy_Atom, ElementBlockScale>; + + static constexpr int AlignmentSFA = 1; + static constexpr int AlignmentSFB = 1; + + // Block scaling smem layout + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned> smem_SFA; + cute::array_aligned> smem_SFB; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + ElementBlockScale const** ptr_SFA; + LayoutSFA layout_SFA; + ElementBlockScale const** ptr_SFB; + LayoutSFB layout_SFB; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + TMA_A tma_load_a; + 
TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + void* tensormaps; + InternalElementA const** ptr_A; + StrideA dA; + InternalElementB const** ptr_B; + StrideB dB; + // Block scaling factors for A and B + ElementBlockScale const** ptr_SFA; + LayoutSFA layout_SFA; + ElementBlockScale const** ptr_SFB; + LayoutSFB layout_SFB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + // Batches/Groups are managed by using appropriate pointers to input matrices + const uint32_t init_L = 1; + // NOTE: Since TMA desc creation with nullptr not possible until 12.6, we use an initial address even when tensor addresses are on device. This address is never used. + InternalElementA const* ptr_A_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_A) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + InternalElementB const* ptr_B_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_B) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + + InternalStrideA stride_a; + InternalStrideB stride_b; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
+ auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b)); + auto tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + auto tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + void* tensormaps = workspace; + + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + tensormaps, + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + args.ptr_SFA, + args.layout_SFA, + args.ptr_SFB, + args.layout_SFB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + bool implementable = true; + constexpr int 
tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the contract that the + // returned tuple must contain at least two elements, with the first two elements being: + // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + // The rest of the tensors can be specified as needed by this collective. 
+ template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params, + ElementBlockScale const* ptr_SFA = nullptr, + ElementBlockScale const* ptr_SFB = nullptr + ) const { + + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t init_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,init_L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,init_L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Make the tiled views of scale tensors + + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(ptr_SFA), + ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, init_L))); // (scale_m,k,l) + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(ptr_SFB), + ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, init_L))); // (scale_n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl); + + } + + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template < + class TensorA, class TensorB, + class TensorMapA, class TensorMapB, + class TensorScaleA, class TensorScaleB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int 
lane_predicate = cute::elect_one_sync(); + // Blockscaling: Tma loads for load_input and CpAsync for load_scale + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), + SmemLayoutSFA{}); // (BLK_M,BLK_K,P) + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), + SmemLayoutSFB{}); // (BLK_N,BLK_K,P) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mSFA_mkl = get<2>(load_inputs); + Tensor mSFB_nkl = get<3>(load_inputs); + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + Tensor gSFA_k = gSFA_mkl(_,_,m_coord,_,l_coord); + Tensor gSFB_k = gSFB_nkl(_,_,n_coord,_,l_coord); + + TiledCopy scale_copy_a = make_tiled_copy(CopyAtomSFA{}, Layout>{}, Layout>{}); + TiledCopy scale_copy_b = make_tiled_copy(CopyAtomSFB{}, Layout>{}, Layout>{}); + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(_0{}); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(_0{}); + + Tensor tSFAgSFA_k = thr_scale_copy_a.partition_S(gSFA_k); + Tensor tSFAsSFA = thr_scale_copy_a.partition_D(sSFA); + + Tensor tSFBgSFB_k = thr_scale_copy_b.partition_S(gSFB_k); + Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < 
size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + // This helps avoid early exit of blocks in Cluster. + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then it would just be acquired since the phase was + // still inverted from make_producer_start_state. 
+ pipeline.producer_tail(smem_pipe_write); + } + } + + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template < + class TensorA, class TensorB, + class TensorSFA, class TensorSFB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_auxiliary( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), + SmemLayoutSFA{}); // (BLK_M,BLK_K,P) + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), + SmemLayoutSFB{}); // (BLK_N,BLK_K,P) + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mSFA_mkl = get<2>(load_inputs); + Tensor mSFB_nkl = get<3>(load_inputs); + Layout layoutSFA = mSFA_mkl.layout(); + Layout layoutSFB = mSFB_nkl.layout(); + + Tensor iSFA_mkl = make_identity_tensor(shape(layoutSFA)); // (m,k,l) + Tensor iSFB_nkl = make_identity_tensor(shape(layoutSFB)); // (n,k,l) + + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor cSFA_mkl = local_tile(iSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor cSFB_nkl = local_tile(iSFB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + Tensor gSFA_k = gSFA_mkl(_,_,m_coord,_,l_coord); + Tensor cSFA_k = cSFA_mkl(_,_,m_coord,_,l_coord); + Tensor gSFB_k = 
gSFB_nkl(_,_,n_coord,_,l_coord); + Tensor cSFB_k = cSFB_nkl(_,_,n_coord,_,l_coord); + + TiledCopy scale_copy_a = make_tiled_copy(CopyAtomSFA{}, Layout>{}, Layout>{}); + TiledCopy scale_copy_b = make_tiled_copy(CopyAtomSFB{}, Layout>{}, Layout>{}); + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(thread_idx); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(thread_idx); + + Tensor tSFAgSFA_k = thr_scale_copy_a.partition_S(gSFA_k); + Tensor tSFAcSFA_k = thr_scale_copy_a.partition_S(cSFA_k); + Tensor tSFAsSFA = thr_scale_copy_a.partition_D(sSFA); + + Tensor tSFBgSFB_k = thr_scale_copy_b.partition_S(gSFB_k); + Tensor tSFBcSFB_k = thr_scale_copy_b.partition_S(cSFB_k); + Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB); + + Tensor tSFApSFA = make_tensor(shape(filter_zeros(tSFAsSFA(_,_,_,_0{})))); // (CPY,CPY_M,CPY_K) + Tensor tSFBpSFB = make_tensor(shape(filter_zeros(tSFBsSFB(_,_,_,_0{})))); // (CPY,CPY_N,CPY_K) + + auto SFA_shape = shape(layoutSFA); + auto SFB_shape = shape(layoutSFB); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // Since scale granularity K is multiple of BLK_K we do not have to consider if that is OOB + bool load_sfa = thread_idx < ScaleMsPerTile; + Tensor tSFAcSFA = tSFAcSFA_k(_,_,_,*k_tile_iter); + Tensor tSFAcSFA_compact = filter_zeros(tSFAcSFA); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSFApSFA); ++i) { + tSFApSFA(i) = load_sfa && elem_less(tSFAcSFA_compact(i), SFA_shape); + } + + bool load_sfb = thread_idx < ScaleNsPerTile; + Tensor tSFBcSFB = tSFBcSFB_k(_,_,_,*k_tile_iter); + Tensor tSFBcSFB_compact = filter_zeros(tSFBcSFB); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSFBpSFB); ++i) { + tSFBpSFB(i) = load_sfb && elem_less(tSFBcSFB_compact(i), SFB_shape); + } + + // + // Copy gmem to smem for *k_tile_iter + // + int write_stage = smem_pipe_write.index(); + + // Copy scale tensors from 
global memory to shared memory + copy_if(scale_copy_a, tSFApSFA, filter_zeros(tSFAgSFA_k(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,write_stage))); + copy_if(scale_copy_b, tSFBpSFB, filter_zeros(tSFBgSFB_k(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,write_stage))); + + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + + template< + class EngineAccum, + class LayoutAccum, + class ScaleFactor + > + CUTLASS_DEVICE + void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor scaleFactor) { + if constexpr (ScalePromotionInterval != 4) { + accumulation.scale_if_needed(scaleFactor); + } + else { + // avoid unnecessary tests when granularity is the finnest + accumulation.scale(scaleFactor); + } + } + + template< + class EngineAccum, + class LayoutAccum, + class ScaleFactor1, + class ScaleFactor2 + > + CUTLASS_DEVICE + void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor1 scaleFactor1, ScaleFactor2 scaleFactor2) { + if constexpr (ScalePromotionInterval != 4) { + accumulation.scale_if_needed(scaleFactor1, scaleFactor2); + } + else { + // avoid unnecessary tests when granularity is the finnest + accumulation.scale(scaleFactor1, scaleFactor2); + } + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + 
static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), make_layout( + make_shape(shape<0>(SmemLayoutSFA{}), + get<1>(TileShape{}), + make_shape(shape<1>(SmemLayoutSFA{}), + shape<2>(SmemLayoutSFA{}))), + make_stride(stride<0>(SmemLayoutSFA{}), _0{}, + make_stride(stride<1>(SmemLayoutSFA{}), + stride<2>(SmemLayoutSFA{}))) + )); // (BLK_M,BLK_N,(BLK_K,P)) + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), make_layout( + make_shape(get<0>(TileShape{}), + shape<0>(SmemLayoutSFB{}), + make_shape(shape<1>(SmemLayoutSFB{}), + shape<2>(SmemLayoutSFB{}))), + make_stride(_0{}, + stride<0>(SmemLayoutSFB{}), + make_stride(stride<1>(SmemLayoutSFB{}), + stride<2>(SmemLayoutSFB{}))) + )); // (BLK_M,BLK_N,(BLK_K,P)) + + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsSFA = tiled_mma.get_slice(thread_idx).partition_C(sSFA); 
// (MMA,MMA_M,MMA_N,(MMA_K,PIPE)) + Tensor tCsSFB = tiled_mma.get_slice(thread_idx).partition_C(sSFB); // (MMA,MMA_M,MMA_N,(MMA_K,PIPE)) + + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Per block scale values for operand A and B + // Since scale factors always broadcast across MMA_K we slice that away + Tensor tCrSFA = make_tensor_like(tCsSFA(_, _, _, _0{})); // (MMA,MMA_M,MMA_N) + Tensor tCrSFB = make_tensor_like(tCsSFB(_, _, _, _0{})); // (MMA,MMA_M,MMA_N) + + // Prologue GMMAs + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + + if (k_tile_count > 0) { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + // Load per block scale values from shared memory to registers + 
copy(tCsSFA(_,_,_,make_coord(_0{},read_stage)), tCrSFA); + copy(tCsSFB(_,_,_,make_coord(_0{},read_stage)), tCrSFB); + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + warpgroup_fence_operand(accumulation()); + + + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrSFB(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { + filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrSFA(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { + filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; + } + } + + warpgroup_wait<0>(); + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA(_0{}); + scale_if_needed(accumulation, scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accumulation, tCrSFA); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFB); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFA, tCrSFB); + } + } + + warpgroup_fence_operand(accumulation()); + + // Mainloop GMMAs + k_tile_count--; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; 
k_tile_count > 1; --k_tile_count) + { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers (at most twice per block along M and/or N) + copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA); + copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); + + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_fence_operand(accumulation()); + + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrSFB(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { + filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrSFA(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { + filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; + } + } + + warpgroup_wait<0>(); + pipeline.consumer_release(smem_pipe_release); // Unlock previous tile + ++smem_pipe_read; + 
barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA(_0{}); + scale_if_needed(accumulation, scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accumulation, tCrSFA); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFB); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFA, tCrSFB); + } + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_release; + } + + if (k_tile_count > 0) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + // Load per block scale values from shared memory to registers (at most twice per block along M and/or N) + copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA); + copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); + + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_fence_operand(accumulation()); + + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrSFA(_0{}) 
= tCrSFA(_0{}) * tCrSFB(_0{}); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrSFB(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { + filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrSFA(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { + filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; + } + } + warpgroup_wait<0>(); + pipeline.consumer_release(smem_pipe_release); // Unlock previous tile + + // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA(_0{}); + scale_if_needed(accumulation, scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accumulation, tCrSFA); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFB); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFA, tCrSFB); + } + } + if constexpr (ScalePromotionInterval != 4) { + // residues only exists when granularity is not the finnest + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA(_0{}); + accumulation.scale_residue_if_needed(scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + accumulation.scale_residue_if_needed(tCrSFA); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(tCrSFB); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(tCrSFA, tCrSFB); + } + } + + warpgroup_fence_operand(accumulation()); + + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, 
PipelineState smem_pipe_release, int k_tile_count) { + if (k_tile_count > 0) { + // The pipeline is not released in the first iteration + smem_pipe_release.advance(k_tile_count - 1); + pipeline.consumer_release(smem_pipe_release); + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor 
- used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + InternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + InternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + 
tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + + template + CUTLASS_DEVICE + InputTensors + tensors_perform_update( + [[maybe_unused]] InputTensors const& input_tensors, + Params const& mainloop_params, + [[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + + if constexpr (IsGroupedGemmKernel) { + return load_init( + problem_shape_mnkl, + mainloop_params, + mainloop_params.ptr_SFA[next_batch], + mainloop_params.ptr_SFB[next_batch] + ); + } else { + auto [gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl] = input_tensors; + + auto scaleA_layout = mScaleA_mkl.layout(); + auto scaleB_layout = mScaleB_nkl.layout(); + + mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA[next_batch]), scaleA_layout); // (m,ScaleMsPerTile,k,l) + mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB[next_batch]), scaleB_layout); // (n,ScaleNsPerTile,k,l) + return 
cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4289bc816b057416f25a7f155a0e72ed5088b034 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp @@ -0,0 +1,676 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape_, + class TileShape_, + class KernelSchedule, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class 
GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90CpAsyncGmmaRmemAWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90CpAsyncGmmaRmemAWarpSpecialized; + using TileShape = TileShape_; + using ClusterShape = ClusterShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + // Swap and transpose A/B for A k-major layout and B mn-major layout since WGMMA is k-major only (e.g. tf32, Fp32, Int8, Fp8 WGMMA) + static constexpr bool IsLayoutAkBmn = + cute::is_same_v, layout::RowMajor> && + cute::is_same_v, layout::RowMajor>; + + static constexpr bool IsInputSizeTwoBytes = sizeof(ElementA) == 2 && sizeof(ElementB) == 2; + static constexpr bool SwapAB = !IsInputSizeTwoBytes && IsLayoutAkBmn; + using InternalGmemTiledCopyA = cute::conditional_t; + using InternalGmemTiledCopyB = cute::conditional_t; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemCopyAtomA = cute::conditional_t; + using InternalSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using InternalElementA = cute::conditional_t; + using InternalElementB = cute::conditional_t; + using InternalStrideA = cute::conditional_t; + using InternalStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineAsync; + using PipelineState = typename MainloopPipeline::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + InternalSmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + InternalSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major only (e.g. tf32, fp32, fp8, int8). 
+ static constexpr bool IsLayoutAmnBmn = + cute::is_same_v, layout::ColumnMajor> && + cute::is_same_v, layout::RowMajor>; + static constexpr bool TransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn; + using TransposeOperandB = decltype(cutlass::transform::collective::detail::make_transpose_operand_b( + 0, 0, TiledMma{}, SmemLayoutB{}, InternalSmemLayoutAtomB{}, + InternalElementB{}, cute::bool_constant{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + + using GmmaSmemLayoutAtomB = decltype(transform::collective::detail::gmma_smem_transpose_or_passthrough< + TransposeB, InternalSmemLayoutAtomB, InternalElementB>()); + + // SmemLayoutB for GMMA is different from SmemLayoutB for TMA if TransposeB + using GmmaSmemLayoutB = decltype(tile_to_shape( + GmmaSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(!SwapAB || !TransposeB, "Cannot SwapAB and TransposeB at the same time."); + static_assert(TransposeB xor (cute::is_same_v), + "Should be same layout if not TransposeB."); + static_assert(!TransposeB || (cutlass::bits_to_bytes(size<1>(SmemLayoutB{}) * sizeof_bits::value)) == 128, + "SmemLayoutB K must be 128bytes to be transposed."); + static_assert(!transform::collective::detail::use_universal_transposition(), + "Warp specialized ARF kernels have not supported universal B transposition yet."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<256, _0> { + cute::array_aligned, 256> smem_A; + cute::array_aligned, 256> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // 
Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + InternalElementA const* ptr_A = nullptr; + InternalStrideA dA{}; + InternalElementB const* ptr_B = nullptr; + InternalStrideB dB{}; + uint32_t mma_promotion_interval = 4; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + if constexpr (not SwapAB) { + return { + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB + }; + } + else { + return { + reinterpret_cast(args.ptr_B), + args.dB, + reinterpret_cast(args.ptr_A), + args.dA + }; + } + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, + class TensorB, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + load( + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_in, + TensorB const& gB_in, + KTileIterator k_tile_iter, int k_tile_count, + 
ResidueMNK residue_mnk, + int thread_idx, + TensorStorage& shared_tensors) + { + using namespace cute; + + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + Tensor gA = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA_in); + Tensor gB = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB_in); + + // Partition the copying of A and B tiles across the threads + InternalGmemTiledCopyA gmem_tiled_copy_a; + InternalGmemTiledCopyB gmem_tiled_copy_b; + auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); + auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = 
gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // 0-th stage with predication on k to account for residue + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + int write_stage = smem_pipe_write.index(); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) + copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,write_stage)); + } + else { + clear(tAsA(_,_,k,write_stage)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + + ++k_tile_iter; + --k_tile_count; + + // UNLOCK smem_pipe_write + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Advance smem_pipe_write + ++smem_pipe_write; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + int write_stage = smem_pipe_write.index(); + + // Copy gmem to smem for *k_tile_iter + copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy_if(gmem_tiled_copy_b, tBpB, 
tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // UNLOCK smem_pipe_write + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + // Issue the epilogue waits + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) + { + using namespace cute; + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + Tensor sA_ = 
make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_M,BLK_K,PIPE) + + // If TransposeB, GMMA will read from transposed B layout SMEM + Tensor gmma_sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), GmmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate fragments and descriptors + Tensor tCsA = mma_thread_slice.partition_A(sA); + Tensor tCrA = mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = mma_warpgroup_slice.partition_B(gmma_sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + + + auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); + + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == 
size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + TransposeOperandB transpose = cutlass::transform::collective::detail::make_transpose_operand_b( + warp_idx, warp_group_thread_idx, tiled_mma, SmemLayoutB{}, + InternalSmemLayoutAtomB{}, InternalElementB{}, + cute::bool_constant{}); + + warpgroup_fence_operand(accum); + // first k tile + { + pipeline.consumer_wait(smem_pipe_read); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + + bool skip_wait = (pipeline.consumer_try_wait(smem_pipe_read) == BarrierStatus::WaitDone); + + // copy smem->rmem for A operand + copy(smem_tiled_copy_A, tCsA(_,_,0,read_stage), tCrA_copy_view(_,_,0)); + // transpose B operand in SMEM + transpose(sB, gmma_sB, read_stage, 0); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + if (k_block == 0) { + transpose(sB, gmma_sB, read_stage, 1); + transpose.synchronize(); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + 
warpgroup_commit_batch(); + } + + warpgroup_wait<2>(); + + + if (k_tile_count - 1 > 0) { + if (!skip_wait) { + pipeline.consumer_wait(smem_pipe_read); + } + copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + --k_tile_count; + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + bool skip_wait = (pipeline.consumer_try_wait(smem_pipe_read) == BarrierStatus::WaitDone); + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + if (k_block == size<2>(tCrA) - 1) { + if (!skip_wait) { + pipeline.consumer_wait(smem_pipe_read); + } + copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + // transpose B operand in SMEM + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); + } else { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + // transpose B operand in SMEM + if (k_block < 2) { + transpose.synchronize(k_block); // make transpose of k_block available + } + if (k_block == 0) { + transpose(sB, gmma_sB, read_stage, 1); + } + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + if (k_block == 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK 
smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + warpgroup_fence_operand(accum); + + } + + warpgroup_fence_operand(accum); + + if (k_tile_count > 0) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + if (k_block < 2) { + transpose.synchronize(k_block); // make k_block transpose available + } + if (k_block == 0) { + transpose(sB, gmma_sB, read_stage, 1); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + if (k_block == 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + warpgroup_fence_operand(accum); + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + 
++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fbbe971c7f338a26d7929fde565288eae11ffa54 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp @@ -0,0 +1,508 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape_, + class TileShape_, + class KernelSchedule, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct 
CollectiveMma< + MainloopSm90CpAsyncGmmaWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90CpAsyncGmmaWarpSpecialized; + using TileShape = TileShape_; + using ClusterShape = ClusterShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineAsync; + using PipelineState = typename MainloopPipeline::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = 
decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + 
return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, + class TensorB, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + load( + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_in, + TensorB const& gB_in, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + TensorStorage& shared_tensors) + { + using namespace cute; + + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + Tensor gA = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA_in); + Tensor gB = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB_in); + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_a; + GmemTiledCopyB gmem_tiled_copy_b; + auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); + auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = 
make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // 0-th stage with predication on k to account for residue + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + int write_stage = smem_pipe_write.index(); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) + copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,write_stage)); + } + else { + clear(tAsA(_,_,k,write_stage)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK smem_pipe_write + 
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Advance smem_pipe_write + ++smem_pipe_write; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + int write_stage = smem_pipe_write.index(); + + // Copy gmem to smem for *k_tile_iter + copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // UNLOCK smem_pipe_write + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + // Issue the epilogue waits + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) + { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced 
instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in 
flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + assert(k_tile_count >= 1); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + warpgroup_fence_operand(accum); + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + warpgroup_arrive(); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count - 1; k_tile_prologue > 0; --k_tile_prologue) { + + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + warpgroup_arrive(); + + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, 
barrier_token); + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8e054370098985ad991f9d0db20f059195e4b96 --- /dev/null +++ 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -0,0 +1,754 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop that source A operand from registers +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaRmemAWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + 
using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + // Swap and transpose A/B for A k-major layout and B mn-major layout since WGMMA is k-major only + // (e.g. tf32, Fp32, Int8, Fp8 WGMMA) + static constexpr bool IsLayoutAkBmn = + cute::is_same_v, layout::RowMajor> && + cute::is_same_v, layout::RowMajor>; + + static constexpr bool IsInputSizeTwoBytes = sizeof(ElementA) == 2 && sizeof(ElementB) == 2; + static constexpr bool SwapAB = !IsInputSizeTwoBytes && IsLayoutAkBmn; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemCopyAtomA = cute::conditional_t; + using InternalSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using InternalElementA = cute::conditional_t; + using InternalElementB = cute::conditional_t; + using InternalStrideA = cute::conditional_t; + using InternalStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + InternalSmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + InternalSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major only (e.g. tf32, fp32, fp8, int8). + static constexpr bool IsLayoutAmnBmn = + cute::is_same_v, layout::ColumnMajor> && + cute::is_same_v, layout::RowMajor>; + static constexpr bool TransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn; + using TransposeOperandB = decltype(cutlass::transform::collective::detail::make_transpose_operand_b( + 0, 0, TiledMma{}, SmemLayoutB{}, InternalSmemLayoutAtomB{}, + InternalElementB{}, cute::bool_constant{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + using GmmaSmemLayoutAtomB = decltype(transform::collective::detail::gmma_smem_transpose_or_passthrough< + TransposeB, InternalSmemLayoutAtomB, InternalElementB>()); + + // SmemLayoutB for GMMA is different from SmemLayoutB for TMA if TransposeB + using GmmaSmemLayoutB = decltype(tile_to_shape( + GmmaSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< 
::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(!SwapAB || !TransposeB, "Cannot SwapAB and TransposeB at the same time."); + static_assert(TransposeB xor (cute::is_same_v), + "Should be same layout if not TransposeB."); + static_assert(!TransposeB || (cutlass::bits_to_bytes((size<1>(SmemLayoutB{}) * sizeof_bits::value))) == 128, + "SmemLayoutB K must be 128bytes to be transposed."); + + static constexpr bool uses_universal_transposition() { + if constexpr (TransposeB) { + return transform::collective::detail::use_universal_transposition(); + } + else { + return false; + } + } + + static_assert(!uses_universal_transposition(), + "Warp specialized ARF kernels have not supported universal B transposition yet."); + + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct { + cute::array_aligned, SmemAlignmentA> smem_A; + cute::array_aligned, SmemAlignmentB> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + 
SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + if constexpr (SwapAB) { + M = get<1>(problem_shape_MNKL); + N = get<0>(problem_shape_MNKL); + } + + InternalElementA const* ptr_A; + InternalStrideA dA; + InternalElementB const* ptr_B; + InternalStrideB dB; + + if constexpr (not SwapAB) { + ptr_A = reinterpret_cast(args.ptr_A); + ptr_B = reinterpret_cast(args.ptr_B); + dA = args.dA; + dB = args.dB; + } + else { + ptr_A = reinterpret_cast(args.ptr_B); + ptr_B = reinterpret_cast(args.ptr_A); + dA = args.dB; + dB = args.dA; + } + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), dB)); + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t 
transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) ; + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and 
mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + 
Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; 
--k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 
2, "InternalSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_M,BLK_K,PIPE) + + // If TransposeB, GMMA will read from transposed B layout SMEM + Tensor gmma_sB_position_dependent = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), + GmmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor gmma_sB = as_position_independent_swizzle_tensor(gmma_sB_position_dependent); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate fragments and descriptors + Tensor tCsA = mma_thread_slice.partition_A(sA); + Tensor tCrA = 
mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = mma_warpgroup_slice.partition_B(gmma_sB_position_dependent); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + + + auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); + + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + Tensor tCsA_copy_view = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCsA_copy_view) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA_copy_view) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(tCrA) > _2{}, "RS loops require more than 2 MMA k-iterations for correctness."); + + // + // PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + TransposeOperandB transpose = cutlass::transform::collective::detail::make_transpose_operand_b( + warp_idx, warp_group_thread_idx, tiled_mma, SmemLayoutB{}, + InternalSmemLayoutAtomB{}, InternalElementB{}, + cute::bool_constant{}); + + warpgroup_fence_operand(accum); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { 
+ barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,0,read_stage), tCrA_copy_view(_,_,0)); + // transpose B operand in SMEM + transpose(sB, gmma_sB, read_stage, 0); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + transpose.synchronize(k_block); + transpose(sB, gmma_sB, read_stage, k_block + 1); + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + if(k_block == 0) { + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + } + + warpgroup_wait<2>(); + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); + warpgroup_commit_batch(); + --k_tile_count; + if(k_tile_count == 0) { + return; + } + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); + warpgroup_wait<2>(); + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); 
+ } + if (k_block == size<2>(tCrA) - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + // transpose B operand in SMEM + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); + } + else { + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + // transpose B operand in SMEM + transpose.synchronize(k_block); // make transpose of k_block available + transpose(sB, gmma_sB, read_stage, k_block + 1); + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + warpgroup_commit_batch(); + warpgroup_wait<2>(); + if (k_block == 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + warpgroup_fence_operand(accum); + + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + transpose.synchronize(k_block); // make k_block transpose available + transpose(sB, gmma_sB, read_stage, k_block + 1); + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + if (k_block == 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, 
tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); + warpgroup_commit_batch(); + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2558350ce38664f8e412deef8b567d558f8a775c --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -0,0 +1,1032 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/collective/mixed_input_utils.hpp" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop that source A operand from registers +template < + int Stages, + class ClusterShape, + class KernelSchedule_, + class TileShape_, + class ElementAOptionalTuple, + class StrideA_, + class ElementBOptionalTuple, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput, + TileShape_, + ElementAOptionalTuple, + StrideA_, + ElementBOptionalTuple, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + 
TransformB_> +{ +public: + + // + // Type Aliases + // + using ConversionMode = cutlass::detail::ConversionMode; + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; + using TileShape = TileShape_; + using KernelSchedule = KernelSchedule_; + +private: + template friend struct detail::MixedInputUtils; + using CollectiveType = CollectiveMma; + using Utils = detail::MixedInputUtils; + + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + +public: + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," + "[ElementZero]}. Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. + using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = StrideA_; + using StrideB = StrideB_; + // These are always MN major + using StrideScale = cute::Stride, int64_t, int64_t>; + // For cases where we can't have a void scale, we can use this to allow the code to compile when the scale is void. 
+ using NonVoidStrideScale = cute::conditional_t< + cute::is_void_v, cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert(( IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value)) || + (!IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value)), + "The transformed type must be K-major."); + + static_assert(( IsATransformed && (sizeof(ElementB) == 2)) || + (!IsATransformed && (sizeof(ElementA) == 2)) || + ((cutlass::gemm::detail::is_k_major() || is_layout::value) && + (cutlass::gemm::detail::is_k_major() || is_layout::value)), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + // Scale layout atom set after swapping. + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using SmemCopyAtomScale = Copy_Atom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using SwappedSmemLayoutAtomA = cute::conditional_t; + using SwappedSmemLayoutAtomB = cute::conditional_t; + using SwappedSmemCopyAtomA = cute::conditional_t; + using SwappedSmemCopyAtomB = cute::conditional_t; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealSwappedElementA = cute::conditional_t; + using RealSwappedElementB = cute::conditional_t; + using SwappedElementA = cute::conditional_t; + using SwappedElementB = cute::conditional_t; + using SwappedStrideA = cute::conditional_t; + using SwappedStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using SwappedTransformA = cute::conditional_t; + using SwappedTransformB = cute::conditional_t; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + using TmaElementScale = uint_bit_t >; // in case we have array. translating to uint to satisfy tma descriptor's specialization + + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync< + DispatchPolicy::Stages>; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + using SmemLayoutAtomScale = Layout(SwappedSmemLayoutAtomA{})), cute::Int<1>>>; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); + + static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % 
size<0>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Tile along modes in a way that maximizes the TMA box size. + + using SmemLayoutA = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomA{}, select<0,2>(TileShape{}), SwappedStrideA{})); + using SmemLayoutB = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomB{}, select<1,2>(TileShape{}), SwappedStrideB{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,NonVoidStrideScale>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. + // We must also handle updating the pipeline transaction bytes on the fly. 
+ // NOTE: Deleting this assertion without required changes will cause the code to hang. + static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); + +private: + static constexpr ConversionMode + get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } + else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } + else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + +public: + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && + cutlass::detail::is_Array_v; + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage + { + static constexpr int scale_elements = Utils::elements_per_smem_scale(); + static constexpr int zero_elements = Utils::elements_per_smem_zero(); + struct TensorStorage { + CUTE_ALIGNAS(SmemAlignmentA) cute::ArrayEngine> smem_A; + CUTE_ALIGNAS(SmemAlignmentB) cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host 
side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + ElementScale const* ptr_S = nullptr; + NonVoidStrideScale dS{}; + int group_size = 0; + ElementZero const* ptr_Z = nullptr; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + public: + + // Assumption: StrideA is congruent with Problem_MK + using LayoutA = decltype(detail::get_gmem_layout(repeat_like(SwappedStrideA{}, int32_t(0)), SwappedStrideA{})); + using LayoutB = decltype(detail::get_gmem_layout(repeat_like(SwappedStrideB{}, int32_t(0)), SwappedStrideB{})); + + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), LayoutA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); // mcast along N mode for this M load, if any + + using TMA_Scale = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + using TMA_Zero = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. 
Scale is ALWAYS loaded with A for RF kernel + + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), LayoutB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + int64_t scale_k; + int group_size; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + int reload_factor = (group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + SwappedStrideA dA; + SwappedStrideB dB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + if constexpr (SwapAB) { + M = get<1>(problem_shape_MNKL); + N = get<0>(problem_shape_MNKL); + } + + SwappedElementA const* ptr_A; + SwappedStrideA dA; + SwappedElementB const* ptr_B; + SwappedStrideB dB; + + if constexpr (not SwapAB) { + ptr_A = reinterpret_cast(args.ptr_A); + ptr_B = reinterpret_cast(args.ptr_B); + dA = args.dA; + dB = args.dB; + } + else { + ptr_A = reinterpret_cast(args.ptr_B); + ptr_B = reinterpret_cast(args.ptr_A); + dA = args.dB; + dB = args.dA; + } + + Tensor tensor_a = make_tensor(detail::get_logical_ptr(ptr_A), detail::get_gmem_layout(make_shape(M,K,L), dA)); + Tensor tensor_b = make_tensor(detail::get_logical_ptr(ptr_B), detail::get_gmem_layout(make_shape(N,K,L), dB)); + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); // mcast along N mode for this M load, if any + + typename 
Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); // mcast along M mode for this N load, if any + + typename Params::TMA_Scale tma_load_scale{}; + typename Params::TMA_Zero tma_load_zero{}; + + uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB }; + } + else if constexpr (ModeHasScales) { + auto scale_k = ceil_div(K, args.group_size); + ElementScale const* ptr_S = args.ptr_S; + StrideScale dS = args.dS; + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(M,scale_k,L), dS)); + tma_load_scale = make_tma_copy( + GmemTiledCopyScale{}, + tensor_scale, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB }; + } + else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(args.ptr_Z), make_layout(make_shape(M,scale_k,L), dS)); + tma_load_zero = make_tma_copy( + GmemTiledCopyScale{}, + tensor_zero, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB }; + } else { + static_assert(cutlass::detail::dependent_false, 
"Conversion mode not handled in to_underlying_arguments."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + bool check_aligned_A = cutlass::detail::check_alignment(detail::get_gmem_layout(cute::make_shape(M,K,L), args.dA)); + + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + bool check_aligned_B = cutlass::detail::check_alignment(detail::get_gmem_layout(cute::make_shape(N,K,L), args.dB)); + + bool check_aligned_S = true; + bool check_aligned_Z = true; + bool check_mode_args = true; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + check_mode_args = check_mode_args && (args.ptr_S == nullptr); + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + const int scale_mn = SwapAB ? 
N : M; + const int scale_k = ceil_div(K, args.group_size); + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + check_aligned_S = cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), args.dS); + check_mode_args = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + check_mode_args = check_mode_args && args.group_size != 0; + check_mode_args = check_mode_args && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + check_aligned_Z = cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), args.dS); + check_mode_args = check_mode_args && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!check_mode_args) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid arguments for the selected conversion mode.\n"); + } + if (!check_aligned_A) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor A meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_B) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor B meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_S) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor S (scale) meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_Z) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor Z (zeros) meet the minimum alignment requirements for TMA.\n"); + } + + return check_mode_args && check_aligned_A && check_aligned_B && check_aligned_S && check_aligned_Z; 
+ } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytesMK = Utils::compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = Utils::compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytesExtra = Utils::compute_tma_transaction_bytes_extra(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesExtra; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(detail::get_gmem_layout(make_shape(M,K,L), mainloop_params.dA))); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(shape(detail::get_gmem_layout(make_shape(N,K,L), mainloop_params.dB))); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(gA_mkl, gB_nkl); + } + else if constexpr (ModeHasScales) { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + else { + 
static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + /// This overload gets triggered when we have scales. + template < + class... Ts, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... 
(Ts) == 4, "Scaled and zero convert needs four inputs"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + auto extra_input_partitions = Utils::partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } + + if constexpr 
(KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + int const scale_load_k = *k_tile_iter / mainloop_params.reload_factor; // This will always be 0 when group_size == K. + if (cute::elect_one_sync()) copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + if (cute::elect_one_sync()) copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + // Issue the epilogue waits + if (cute::elect_one_sync()) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * 
still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SwappedSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SwappedSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for RF sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout 
warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsA = mma_thread_slice.partition_A(sA); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate fragments and descriptors + Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + + Tensor tCrA_load = [&]{ + if constexpr (not is_layout::value) { + // Make register tensor with MMA layout + return make_fragment_like(tCrA_mma); + } + else { + // Make register tensor matching smem layout, converter will take care of de-swizzling + return make_tensor_like(tCsA(_,_,_,Int<0>{})); + } + }(); + + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + auto smem_tiled_copy_A = make_tiled_copy_A(SwappedSmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = Utils::partition_extra_mma_info(mma_thread_slice, shared_tensors); + auto copy_partitions_extra_info = Utils::retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + 
CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + constexpr int K_WAIT_MAX = cute::min(K_BLOCK_MAX - 1, 7); + static_assert(K_BLOCK_MAX >= 4, "Consider increasing TileShapeK"); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); + if (K_BLOCK_MAX > 1) { // prefetch next block + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, read_stage); + } + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } 
+ } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first mma. + pipeline.consumer_wait(smem_pipe_read, barrier_token); + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); + if (K_BLOCK_MAX > 1) { // prefetch next block + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + } + warpgroup_wait(); + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); + if (K_BLOCK_MAX > 1) { // prefetch next 
block + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + } + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + } + else { + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + warpgroup_fence_operand(accum); + + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + 
smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp new file mode 100644 index 0000000000000000000000000000000000000000..228c25894dbcf8aac3eedd4dd54e07609b5eb365 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp @@ -0,0 +1,538 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class ClusterShape, + int PipelineAsyncMmaStages, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + 
class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmma, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmma; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + static constexpr int ThreadCount = CUTE_STATIC_V(size(TiledMma{})); + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must 
evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + struct SharedStorage { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = 
make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return { + tma_load_a, + tma_load_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TMA_LOAD_A, + class TensorB, class 
TMA_LOAD_B, + class FrgTensorC, + class KTileIterator + > + CUTLASS_DEVICE void + operator() ( + TensorA const& gA, TMA_LOAD_A& tma_load_a, + TensorB const& gB, TMA_LOAD_B& tma_load_b, + FrgTensorC& accum, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + char* shared_memory, + Params const& mainloop_params) + { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + SharedStorage& storage = *reinterpret_cast(shared_memory); + Tensor sA = make_tensor(make_smem_ptr(storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // 
(TMA,TMA_N,TMA_K,PIPE) + + // + // Prepare TMA membars and PREFETCH + // + + // Number of pipelined k-tiles in smem + constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + + // NOTE: Another parameter: Partition the pipeline between active MMAs and active TMAs + // Tunable via the dispatch policy to tollerate latencies evenly across the math and compute stages + // K_PIPE_MMAS: The max number of active MMA pipes at beginning of every loop + // K_PIPE_TMAS: The max number of active TMA pipes at beginning of every loop (geq 1) + constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; + constexpr int K_PIPE_TMAS = K_PIPE_MAX - K_PIPE_MMAS; + static_assert(0 <= K_PIPE_MMAS && K_PIPE_MMAS < K_PIPE_MAX); + static_assert(0 < K_PIPE_TMAS && K_PIPE_TMAS <= K_PIPE_MAX); + + static_assert(K_PIPE_MMAS < K_PIPE_MAX - 1); + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr uint32_t TmaTransactionBytes = static_cast( + cutlass::bits_to_bytes(size<0>(sA) * size<1>(sA) * sizeof_bits::value) + + cutlass::bits_to_bytes(size<0>(sB) * size<1>(sB) * sizeof_bits::value)); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + + PipelineParams params; + params.transaction_bytes = TmaTransactionBytes; + params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; + params.is_leader = warp_group_thread_idx == 0; + params.num_consumers = NumThreadsPerWarpGroup; + + MainloopPipeline pipeline(storage.pipeline_storage, params, ClusterShape{}); + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineState smem_pipe_read; + PipelineState smem_pipe_release; + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + + // We need this to guarantee that the Pipeline init is visible + 
// To all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } + else { + __syncthreads(); + } + + // Set predicate for the lowest lane_id in the warp + int lane_predicate = cute::elect_one_sync(); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + // Keep a copy to know when to stop issuing loads + int k_tile_count_tma = k_tile_count; + + // Issue TmaLoads (Prologue fetches) + if (warp_idx == 0 && lane_predicate == 1) { + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Issue the prologue loads + int prologue_tma_count = min(K_PIPE_MAX, k_tile_count); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < prologue_tma_count; ++stage) { + pipeline.producer_acquire(smem_pipe_write); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,stage)); + copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,stage)); + ++k_tile_iter; + ++smem_pipe_write; + } + k_tile_count_tma -= prologue_tma_count; + } + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + 
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + __syncthreads(); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + // Prologue MMAs + assert(k_tile_count >= 1); + { + // WAIT on smem_pipe_read until it's data is available + pipeline.consumer_wait(smem_pipe_read); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,smem_pipe_read.index()), tCrB(_,_,k_block,smem_pipe_read.index()), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + ++smem_pipe_read; + 
--k_tile_count; + } + + CUTLASS_PRAGMA_UNROLL + for (int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count) - 1; + prologue_mma_count > 0; --prologue_mma_count) + { + // WAIT on smem_pipe_read until it's data is available + pipeline.consumer_wait(smem_pipe_read); + warpgroup_arrive(); + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); + warpgroup_commit_batch(); + ++smem_pipe_read; + --k_tile_count; + } + warpgroup_fence_operand(accum); + + // + // PIPELINED MAIN LOOP + // + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until data is available + pipeline.consumer_wait(smem_pipe_read); + + // + // Compute on k_tile + // + + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK wr stage, done _computing_ on it + + // + // Copy gmem to smem for *k_tile_iter + // + + // Do Acquire & Load only if needed - helps with both performance and also corner case illegal barrier-ops + if (warp_idx == 0 && lane_predicate == 1 && (k_tile_count_tma > 0) ) { + pipeline.producer_acquire(smem_pipe_write); // LOCK wr stage, for _writing_ + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write.index())); + copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write.index())); + ++smem_pipe_write; + ++k_tile_iter; + 
--k_tile_count_tma; + } + + // Advance consumer pipeline + ++smem_pipe_read; + ++smem_pipe_release; + } + + // Wait on all GMMAs + warpgroup_wait<0>(); + warpgroup_fence_operand(accum); + + // Workaround for ensuring Smem destruction doesn't happen accidentally + if constexpr (size(typename DispatchPolicy::ClusterShape{}) > 1) { + cute::cluster_arrive(); + cute::cluster_wait(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0e64bad5d2e406156f2532fc5420a662ba3d0687 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,584 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + 
class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + 
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = 
problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * 
size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = 
get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), 
tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename 
TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + assert(k_tile_count >= 1); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + warpgroup_fence_operand(accum); + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + 
warpgroup_arrive(); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count - 1; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + 
pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c7ea65a6fdbecf31f82a8a51fc390137dd1b16c6 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -0,0 +1,587 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedFP8, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = 
SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,0), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + 
make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + TileShape{}, + ClusterShape{})); + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t mma_promotion_interval = 4; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.mma_promotion_interval + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; 
+ constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + /* MMA promotion interval should be a multiple of the number of MMA instructions issued by each mainloop iteration. */ + implementable = implementable && (args.mma_promotion_interval % (size<2>(TileShape{})() / TiledMma().template tile_size_mnk<2>()()) == 0); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // 
Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = 
pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = 
make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; 
+ + GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + accumulation.promote_if_needed(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + 
tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation()); + + accumulation.promote_if_needed(); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation.promote_residue_if_needed(); + + warpgroup_fence_operand(accumulation()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp new file mode 100644 index 0000000000000000000000000000000000000000..48ddf7a0d7b76350911e674f2d1cb3bd4b661921 --- /dev/null +++ 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -0,0 +1,1102 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm80.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/detail/blockwise_scale_layout.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StridePairA_, + class ElementB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8, + TileShape_, + ElementA_, + StridePairA_, + ElementB_, + StridePairB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = cute::tuple_element_t<0,StridePairA_>; + using LayoutSFA = cute::tuple_element_t<1,StridePairA_>; + using ElementB = ElementB_; + using StrideB = cute::tuple_element_t<0,StridePairB_>; + using LayoutSFB = 
cute::tuple_element_t<1,StridePairB_>; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementBlockScale = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScaleTMA = cute::SM90_TMA_LOAD; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static constexpr int ScaleGranularityM = size<0,0>(LayoutSFA{}); + static constexpr int ScaleGranularityN = size<0,0>(LayoutSFB{}); + static constexpr int ScaleGranularityK = size<1,0>(LayoutSFA{}); + + static_assert(size<2>(TileShape{}) % ScaleGranularityK == 0); + static_assert(ScaleGranularityK % size<2>(typename TiledMma::AtomShape_MNK{}) == 0); + + static constexpr int ScalePromotionInterval = ScaleGranularityK / size<2>(typename TiledMma::AtomShape_MNK{}); + static_assert(ScalePromotionInterval % 4 == 0, "ScalePromotionInterval must be a multiple of 4."); + static_assert(ScalePromotionInterval >= size<2>(TileShape{}) / tile_size<2>(TiledMma{}), + "ScalePromotionInterval must be greater than or equal to the number of stages of the MMA atom."); + static_assert(ScalePromotionInterval % (size<2>(TileShape{}) / tile_size<2>(TiledMma{})) == 0, + "ScalePromotionInterval must be a multiple of the number of stages of the MMA atom."); + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + + static constexpr bool MMajorSFA = 
size<0,1>(LayoutSFA{}.stride()) == 1; + static constexpr bool NMajorSFB = size<0,1>(LayoutSFB{}.stride()) == 1; + + static constexpr int ScaleTmaThreshold = 32; + static constexpr bool IsTmaLoadSFA = ScaleMsPerTile >= ScaleTmaThreshold && ScaleNsPerTile < ScaleTmaThreshold && MMajorSFA; + static constexpr bool IsTmaLoadSFB = ScaleNsPerTile >= ScaleTmaThreshold && ScaleMsPerTile < ScaleTmaThreshold && NMajorSFB; + // Two threads per CTA are producers (1 for operand tile `tma`, and 32 for scales `cp.async`) + static constexpr int NumProducerThreadEvents = ((IsTmaLoadSFA && IsTmaLoadSFB)? 1 : 33); + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + static_assert((size<1>(TileShape{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N."); + + using ScaleConfig = ::cutlass::detail::Sm90BlockwiseScaleConfig< + ScaleGranularityM, + ScaleGranularityN, + ScaleGranularityK, + MMajorSFA ? cute::GMMA::Major::MN : cute::GMMA::Major::K, + NMajorSFB ? 
cute::GMMA::Major::MN : cute::GMMA::Major::K>; + using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(TileShape{})); + using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(TileShape{})); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + // we can have partial tiles in M or N, so don't vectorize those loads + using CopyAtomSFA = Copy_Atom, ElementBlockScale>; + using CopyAtomSFB = Copy_Atom, ElementBlockScale>; + + static constexpr int AlignmentSFA = IsTmaLoadSFA ? 128 / cutlass::sizeof_bits::value : 1; + static constexpr int AlignmentSFB = IsTmaLoadSFB ? 
128 / cutlass::sizeof_bits::value : 1; + + // Block scaling smem layout + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; // TILE_M x PIPE_K + cute::array_aligned> smem_B; // TILE_N x PIPE_K + CUTE_ALIGNAS(128) cute::array> smem_SFA; // ScaleMsPerTile x PIPE_K + CUTE_ALIGNAS(128) cute::array> smem_SFB; // ScaleNsPerTile x PIPE_K + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + ElementBlockScale const* ptr_SFA; + LayoutSFA layout_SFA; + ElementBlockScale const* ptr_SFB; + LayoutSFB layout_SFB; + }; + + // Device side kernel params + struct Params { + static auto getTmaSFA() { + if constexpr (IsTmaLoadSFA) { + return make_tma_copy( + GmemTiledCopyScaleTMA{}, 
+ make_tensor(static_cast(nullptr), filter_zeros(LayoutSFA{})), + filter_zeros(SmemLayoutSFA{}(_,_,_0{})), + Shape, Int<1>>{}, + _1{}); + } + else { + return nullptr; + } + } + static auto getTmaSFB() { + if constexpr (IsTmaLoadSFB) { + return make_tma_copy( + GmemTiledCopyScaleTMA{}, + make_tensor(static_cast(nullptr), filter_zeros(LayoutSFB{})), + filter_zeros(SmemLayoutSFB{}(_,_,_0{})), + Shape, Int<1>>{}, + _1{}); + } + else { + return nullptr; + } + } + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_0{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_0{}), + TileShape{}, + ClusterShape{})); + // NOTE: Does make_tma_copy supports 0 stride? 
+ using TMA_SFA = decltype(getTmaSFA()); + using TMA_SFB = decltype(getTmaSFB()); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + // Block scaling factors for A and B + ElementBlockScale const* ptr_SFA; + ElementBlockScale const* ptr_SFB; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + auto ptr_SFA = reinterpret_cast(args.ptr_SFA); + auto ptr_SFB = reinterpret_cast(args.ptr_SFB); + + Tensor tensor_sfa = make_tensor(ptr_SFA, filter_zeros(args.layout_SFA)); + Tensor tensor_sfb = make_tensor(ptr_SFB, filter_zeros(args.layout_SFB)); + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_SFA tma_load_sfa{}; + if constexpr (IsTmaLoadSFA) { + tma_load_sfa = make_tma_copy( + GmemTiledCopyScaleTMA{}, + tensor_sfa, + filter_zeros(SmemLayoutSFA{})(_,_,cute::Int<0>{}), + Shape, Int<1>>{}, + _1{}); + } + typename Params::TMA_SFB tma_load_sfb{}; 
+ if constexpr (IsTmaLoadSFB) { + tma_load_sfb = make_tma_copy( + GmemTiledCopyScaleTMA{}, + tensor_sfb, + filter_zeros(SmemLayoutSFB{})(_,_,cute::Int<0>{}), + Shape, Int<1>>{}, + _1{}); + } + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes_sfa = TmaTransactionBytesSFA; + uint32_t transaction_bytes_sfb = TmaTransactionBytesSFB; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk + transaction_bytes_sfa + transaction_bytes_sfb; + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.ptr_SFA, + args.ptr_SFB, + args.layout_SFA, + args.layout_SFB, + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + if (!cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{})) { + implementable = false; + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load tensor A.\n"); + } + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + if (!cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{})) { + implementable = false; + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load tensor B.\n"); + } + constexpr int min_tma_aligned_elements_S = tma_alignment_bits / cutlass::sizeof_bits::value; + if (IsTmaLoadSFA && !cutlass::detail::check_alignment(args.layout_SFA)) { + implementable = false; + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet 
the minimum alignment requirements for using TMA to load scale A.\n"); + } + if (IsTmaLoadSFB && !cutlass::detail::check_alignment(args.layout_SFB)) { + implementable = false; + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load scale B.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + static constexpr uint32_t TmaTransactionBytesSFA = + (IsTmaLoadSFA? cutlass::bits_to_bytes(ScaleMsPerTile * static_cast(sizeof_bits::value)): 0); + static constexpr uint32_t TmaTransactionBytesSFB = + (IsTmaLoadSFB? cutlass::bits_to_bytes(ScaleNsPerTile * static_cast(sizeof_bits::value)): 0); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesSFA + TmaTransactionBytesSFB; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + if constexpr (IsTmaLoadSFA) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_sfa.get_tma_descriptor()); + } + if constexpr (IsTmaLoadSFB) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_sfb.get_tma_descriptor()); + } + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Note that mSFA_mkl and mSFB_nkl are already blocked tiled in the `m` host and + // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mSFA_mkl and mSFB_nkl. 
+ auto mSFA_mkl = [&]() { + if constexpr (IsTmaLoadSFA) { + return mainloop_params.tma_load_sfa.get_tma_tensor(shape(filter_zeros(mainloop_params.layout_SFA))); + } + else { + return make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA), mainloop_params.layout_SFA); // (scale_m,k,l) + } + }(); + auto mSFB_nkl = [&]() { + if constexpr (IsTmaLoadSFB) { + return mainloop_params.tma_load_sfb.get_tma_tensor(shape(filter_zeros(mainloop_params.layout_SFB))); + } + else { + return make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB), mainloop_params.layout_SFB); // (scale_n,k,l) + } + }(); + + return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class TensorScaleA, class TensorScaleB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + // Blockscaling: Tma loads for load_input and CpAsync for load_scale + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), filter_zeros(SmemLayoutSFA{})); // (ScaleMsPerTile,PIPE) + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), filter_zeros(SmemLayoutSFB{})); // (ScaleNsPerTile,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % 
cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor mSFA_mkl = get<2>(load_inputs); + Tensor mSFB_nkl = get<3>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gSFA = local_tile( + mSFA_mkl, make_tile(Int{}, Int<1>{}), + make_coord(m_coord,_,l_coord)); + Tensor gSFB = local_tile( + mSFB_nkl, make_tile(Int{}, Int<1>{}), + make_coord(n_coord,_,l_coord)); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + auto [tAgA_SFA, tAsA_SFA] = [&]() { + if constexpr (IsTmaLoadSFA) { + auto block_tma_sfa = mainloop_params.tma_load_sfa.get_slice(cluster_local_block_id.y); + Tensor tAgA_SFA_ = block_tma_sfa.partition_S(gSFA); + Tensor tAsA_SFA_ = block_tma_sfa.partition_D(sSFA); + return cute::make_tuple(tAgA_SFA_, tAsA_SFA_); + } + else { + return cute::make_tuple(0, 0); + } + }(); + auto [tBgB_SFB, tBsB_SFB] = [&]() { + if constexpr (IsTmaLoadSFB) { + auto block_tma_sfb = mainloop_params.tma_load_sfb.get_slice(cluster_local_block_id.y); + Tensor tBgB_SFB_ = block_tma_sfb.partition_S(gSFB); + Tensor tBsB_SFB_ = block_tma_sfb.partition_D(sSFB); + return cute::make_tuple(tBgB_SFB_, tBsB_SFB_); + } + else { + return cute::make_tuple(0, 0); + } + }(); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t 
mcast_mask_sf = 0; + + // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + int write_stage = smem_pipe_write.index(); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + // Copy operands A and B from global memory to shared memory + if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + // Copy scale tensors from global memory to shared memory + if constexpr (IsTmaLoadSFA) { + if (lane_predicate) { + copy(mainloop_params.tma_load_sfa.with(*tma_barrier, mcast_mask_sf), tAgA_SFA(_,_,_,*k_tile_iter), tAsA_SFA(_,_,_,write_stage)); + } + } + if constexpr (IsTmaLoadSFB) { + if (lane_predicate) { + copy(mainloop_params.tma_load_sfb.with(*tma_barrier, mcast_mask_sf), tBgB_SFB(_,_,_,*k_tile_iter), tBsB_SFB(_,_,_,write_stage)); + } + } + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + template < + class TensorA, class TensorB, + class TensorScaleA, class TensorScaleB, + class 
KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_auxiliary( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), SmemLayoutSFA{}); // (ScaleMsPerTile,k) + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), SmemLayoutSFB{}); // (ScaleNsPerTile,k) + + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + + Tensor mSFA_mkl = get<2>(load_inputs); + Tensor mSFB_nkl = get<3>(load_inputs); + + Tensor iSFA_mkl = make_identity_tensor(shape(mainloop_params.layout_SFA)); // (m,k,l) + Tensor iSFB_nkl = make_identity_tensor(shape(mainloop_params.layout_SFB)); // (n,k,l) + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor cSFA_mkl = local_tile(iSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor cSFB_nkl = local_tile(iSFB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + Tensor gSFA_k = gSFA_mkl(_,_,m_coord,_,l_coord); + Tensor cSFA_k = cSFA_mkl(_,_,m_coord,_,l_coord); + Tensor gSFB_k = gSFB_nkl(_,_,n_coord,_,l_coord); + Tensor cSFB_k = cSFB_nkl(_,_,n_coord,_,l_coord); + + TiledCopy scale_copy_a = make_tiled_copy(CopyAtomSFA{}, + Layout>{}, Layout>{}); + TiledCopy scale_copy_b = make_tiled_copy(CopyAtomSFB{}, + Layout>{}, Layout>{}); + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(thread_idx); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(thread_idx); + + Tensor 
tSFAgSFA_k = thr_scale_copy_a.partition_S(gSFA_k); + Tensor tSFAcSFA_k = thr_scale_copy_a.partition_S(cSFA_k); + Tensor tSFAsSFA = thr_scale_copy_a.partition_D(sSFA); + + Tensor tSFBgSFB_k = thr_scale_copy_b.partition_S(gSFB_k); + Tensor tSFBcSFB_k = thr_scale_copy_b.partition_S(cSFB_k); + Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB); + + Tensor tSFApSFA = make_tensor(shape(filter_zeros(tSFAsSFA(_,_,_,_0{})))); // (CPY,CPY_M,CPY_K) + Tensor tSFBpSFB = make_tensor(shape(filter_zeros(tSFBsSFB(_,_,_,_0{})))); // (CPY,CPY_N,CPY_K) + + auto SFA_shape = shape(mainloop_params.layout_SFA); + auto SFB_shape = shape(mainloop_params.layout_SFB); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // Since scale granularity K is multiple of BLK_K we do not have to consider if that is OOB + bool load_sfa = thread_idx < ScaleMsPerTile; + Tensor tSFAcSFA = tSFAcSFA_k(_,_,_,*k_tile_iter); + Tensor tSFAcSFA_compact = filter_zeros(tSFAcSFA); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSFApSFA); ++i) { + tSFApSFA(i) = load_sfa && elem_less(tSFAcSFA_compact(i), SFA_shape); + } + + bool load_sfb = thread_idx < ScaleNsPerTile; + Tensor tSFBcSFB = tSFBcSFB_k(_,_,_,*k_tile_iter); + Tensor tSFBcSFB_compact = filter_zeros(tSFBcSFB); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSFBpSFB); ++i) { + tSFBpSFB(i) = load_sfb && elem_less(tSFBcSFB_compact(i), SFB_shape); + } + int write_stage = smem_pipe_write.index(); + // Copy scale tensors from global memory to shared memory + if constexpr (!IsTmaLoadSFA) { + copy_if(scale_copy_a, tSFApSFA, filter_zeros(tSFAgSFA_k(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,write_stage))); + } + if constexpr (!IsTmaLoadSFB) { + copy_if(scale_copy_b, tSFBpSFB, filter_zeros(tSFBgSFB_k(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,write_stage))); + } + if constexpr (!IsTmaLoadSFA || !IsTmaLoadSFB) { + 
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + } + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + template< + class EngineAccum, + class LayoutAccum, + class ScaleFactor + > + CUTLASS_DEVICE + void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor scaleFactor) { + if constexpr (ScalePromotionInterval != 4) { + accumulation.scale_if_needed(scaleFactor); + } + else { + // avoid unnecessary tests when granularity is the finnest + accumulation.scale(scaleFactor); + } + } + template< + class EngineAccum, + class LayoutAccum, + class ScaleFactor1, + class ScaleFactor2 + > + CUTLASS_DEVICE + void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor1 scaleFactor1, ScaleFactor2 scaleFactor2) { + if constexpr (ScalePromotionInterval != 4) { + accumulation.scale_if_needed(scaleFactor1, scaleFactor2); + } + else { + // avoid unnecessary tests when granularity is the finnest + accumulation.scale(scaleFactor1, scaleFactor2); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + + 
static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), make_layout( + make_shape(get<0>(shape(SmemLayoutSFA{})), + get<1>(TileShape{}), + make_shape(get<1>(shape(SmemLayoutSFA{})), + get<2>(shape(SmemLayoutSFA{})))), + make_stride(get<0>(stride(SmemLayoutSFA{})), _0{}, + make_stride(get<1>(stride(SmemLayoutSFA{})), get<2>(stride(SmemLayoutSFA{})))) + )); // (BLK_M,BLK_N,(BLK_K,P)) + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), make_layout( + make_shape(get<0>(TileShape{}), + get<0>(shape(SmemLayoutSFB{})), + make_shape(get<1>(shape(SmemLayoutSFB{})), + get<2>(shape(SmemLayoutSFB{})))), + make_stride(_0{}, + get<0>(stride(SmemLayoutSFB{})), + make_stride(get<1>(stride(SmemLayoutSFB{})), + get<2>(stride(SmemLayoutSFB{})))) + )); // (BLK_M,BLK_N,(BLK_K,P)) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + 
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsSFA = tiled_mma.get_slice(thread_idx).partition_C(sSFA); // (MMA,MMA_M,MMA_N,(MMA_K,PIPE)) + Tensor tCsSFB = tiled_mma.get_slice(thread_idx).partition_C(sSFB); // (MMA,MMA_M,MMA_N,(MMA_K,PIPE)) + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Per block scale values for operand A and B + // Since scale factors always broadcast across MMA_K we slice that away + Tensor tCrSFA = make_tensor_like(tCsSFA(_, _, _, _0{})); // (MMA,MMA_M,MMA_N) + Tensor tCrSFB = make_tensor_like(tCsSFB(_, _, _, _0{})); // (MMA,MMA_M,MMA_N) + + // Prologue GMMAs + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = 
pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + { + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers + copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA); + copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + warpgroup_fence_operand(accumulation()); + + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrSFB(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { + filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrSFA(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { + filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; + } + } + warpgroup_wait<0>(); + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA(_0{}); + scale_if_needed(accumulation, scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accumulation, tCrSFA); + } + if constexpr 
(ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFB); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFA, tCrSFB); + } + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) + { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers (at most twice per block along M and/or N) + copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA); + copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); + + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_fence_operand(accumulation()); + + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrSFB(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { + filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile 
> 1) { + ElementBlockScale scale_a = tCrSFA(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { + filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; + } + } + warpgroup_wait<0>(); + pipeline.consumer_release(smem_pipe_release); // Unlock previous tile + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA(_0{}); + scale_if_needed(accumulation, scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accumulation, tCrSFA); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFB); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFA, tCrSFB); + } + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_release; + } + if (k_tile_count) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + // Load per block scale values from shared memory to registers (at most twice per block along M and/or N) + copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA); + copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB); + + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), 
accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_fence_operand(accumulation()); + + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{}); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrSFB(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) { + filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrSFA(_0{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) { + filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a; + } + } + warpgroup_wait<0>(); + pipeline.consumer_release(smem_pipe_release); // Unlock previous tile + // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA(_0{}); + scale_if_needed(accumulation, scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accumulation, tCrSFA); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFB); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrSFA, tCrSFB); + } + } + if constexpr (ScalePromotionInterval != 4) { + // residues only exists when granularity is not the finnest + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrSFA(_0{}); + accumulation.scale_residue_if_needed(scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + accumulation.scale_residue_if_needed(tCrSFA); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + 
accumulation.scale_residue_if_needed(tCrSFB); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(tCrSFA, tCrSFB); + } + } + + warpgroup_fence_operand(accumulation()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // The pipeline is not released in the first iteration + smem_pipe_release.advance(k_tile_count - 1); + pipeline.consumer_release(smem_pipe_release); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..220e996a8611a4e3f666380ddab977bc535849a8 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,748 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class LayoutPairAE_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedSparse, + TileShape_, + ElementA_, + LayoutPairAE_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedSparse; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaRaw = typename ElementAMma::raw_type; + using LayoutPairAE = LayoutPairAE_; + using LayoutA = remove_cvref_t(LayoutPairAE{}))>; + using LayoutE = 
remove_cvref_t(LayoutPairAE{}))>; + using StrideA = decltype(cute::stride(LayoutA{})); + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + using ElementEMma = typename TiledMma::ValTypeE; + using ElementE = typename ElementEMma::raw_type; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + + static_assert(is_sparse::value, "ElementAMma is sparse"); + static_assert(!is_sparse::value, "ElementA is not sparse"); + + static constexpr int ElementAMmaSparsity = ElementAMma::sparsity; + static constexpr int ElementEMmaSparsity = ElementEMma::sparsity; + + // LayoutA is nested in the stride due to the sparsity. + static constexpr bool is_A_mn_major = cute::is_same_v(LayoutA{}.stride())), Int>; + static constexpr bool is_B_mn_major = cutlass::gemm::detail::is_major<0,StrideB>(); + + using SparseConfig = cutlass::Sm90GemmSparseConfig(TileShape{}),_128{}))>; + + // The offline permutation for the metadata. 
+ using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + + // Metadata pathways + using SmemCopyAtomE = AutoVectorizingCopy; + using GmemCopyAtomE = GmemTiledCopyA; + + using CtaShape_MNK = TileShape; + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M,K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (N,K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{})); + using SmemLayoutE = decltype(tile_to_shape( + SmemLayoutAtomE{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ using TmaInternalElementA = cute::sparse_elem, + cutlass::tfloat32_t, + uint_bit_t>>>; + using TmaInternalElementB = cute::conditional_t, + tfloat32_t, + uint_bit_t>>; + + struct SharedStorage + { + struct TensorStorage { + alignas(128) cute::ArrayEngine> smem_A; + alignas(128) cute::ArrayEngine> smem_B; + alignas(128) cute::ArrayEngine> smem_E; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 0; + + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutE{})) * cute::sizeof_bits_v); + + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutB{})) * cute::sizeof_bits_v); + + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{}; + LayoutA layout_a{}; + ElementB const* ptr_B{}; + StrideB dB{}; + ElementE const* ptr_E{}; + LayoutE layout_e{}; + }; + + // Device side kernel params + struct Params { + + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), LayoutA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); // mcast along N mode for this M load, if any + + using TMA_E = decltype(make_tma_copy_A_sm90( // use uint64_t to get the largest loading box. 
+ GmemCopyAtomE{}, + make_tensor(recast_ptr(nullptr), LayoutE{}), + SmemLayoutE{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); // mcast along N mode for this M load, if any + + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); // mcast along M mode for this N load, if any + + TMA_A tma_load_a; + TMA_E tma_load_e; + TMA_B tma_load_b; + LayoutA layout_a; + LayoutE layout_e; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_E = recast_ptr(args.ptr_E); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, args.layout_a); + Tensor tensor_e = make_tensor(ptr_E, args.layout_e); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); // mcast along N mode for this M load, if any + + typename Params::TMA_E tma_load_e = make_tma_copy_A_sm90( // use uint64_t to get the largest loading box. 
+ GmemCopyAtomE{}, + tensor_e, + SmemLayoutE{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); // mcast along N mode for this M load, if any + + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); // mcast along M mode for this N load, if any + + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_e, + tma_load_b, + args.layout_a, + args.layout_e, + transaction_bytes + }; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool size_check = true; + // Check Alignment A + if constexpr (is_A_mn_major) { + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(M,K/2,L), cute::make_stride(_1{}, M, M*K/2)); + } + else { // If A is K-major + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(M,K/2,L), cute::make_stride(K/2, _1{}, M*K/2)); + } + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!size_check) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + // Check if layout_a and layout_e is filled correctly + auto layout_a_ref = SparseConfig::fill_layoutA(problem_shape_MNKL); + auto layout_e_ref = SparseConfig::fill_layoutE(problem_shape_MNKL); + bool layout_check = true; + layout_check = 
layout_check && (layout_a_ref == args.layout_a); + layout_check = layout_check && (layout_e_ref == args.layout_e); + + if (!layout_check) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Layout_a/e mismatch.\n"); + } + + return size_check && layout_check; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_e.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(mainloop_params.layout_a.shape()); // (m,k,l) + Tensor mE_mkl = mainloop_params.tma_load_e.get_tma_tensor(mainloop_params.layout_e.shape()); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gE_mkl = local_tile(mE_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, gE_mkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, class TensorE, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), 
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + auto [gA_mkl, gB_nkl, gE_mkl] = load_inputs; + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + auto cta_coord_mnk = cta_layout_mnk.get_flat_coord(block_rank_in_cluster); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_e = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<0>(cta_layout_mnk, cta_coord_mnk); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(get<1>(cta_coord_mnk)); + auto block_tma_e = mainloop_params.tma_load_e.get_slice(get<1>(cta_coord_mnk)); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(get<0>(cta_coord_mnk)); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gE = gE_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tEgE = block_tma_e.partition_S(gE); // (TMA,TMA_M,TMA_K,k) + Tensor tEsE = block_tma_e.partition_D(sE); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = 
smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_e.with(*tma_barrier, mcast_mask_e), tEgE(_,_,_,*k_tile_iter), tEsE(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutE{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sE = as_position_independent_swizzle_tensor( + 
make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{})); // (BLK_M,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + auto copy_atom_E = Copy_Atom{}; + + Tensor tCsE = partition_E(thread_mma, sE(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrE = make_fragment_like(tCsE); // (MMA,MMA_M,MMA_K) + + auto smem_tiled_copy_E = make_tiled_copy_E(copy_atom_E, tiled_mma); + auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(thread_idx); + + Tensor tEsE = smem_thr_copy_E.partition_S(sE); // (ECPY,ECPY_M,ECPY_K) + 
Tensor tErE = smem_thr_copy_E.retile_D(tCrE); // (ECPY,ECPY_M,ECPY_K) + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + int read_stage = smem_pipe_read.index(); + + // Load metadata smem->rmem for one stage + copy(smem_tiled_copy_E, tEsE(_,_,_,read_stage), tErE); + + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, make_zip_tensor(tCrA(_,_,k_block,read_stage), tErE(_,_,k_block)), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + int read_stage = smem_pipe_read.index(); + + // Load metadata smem->rmem for one stage + copy(smem_tiled_copy_E, tEsE(_,_,_,read_stage), tErE); + + warpgroup_fence_operand(accum); + warpgroup_arrive(); + 
// Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, make_zip_tensor(tCrA(_,_,k_block,read_stage), tErE(_,_,k_block)), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + +private: + + template + CUTE_HOST_DEVICE static constexpr + auto + thrfrg_E(TiledMMA const& mma, ETensor&& etensor) + { + using TiledMma = TiledMMA; + + CUTE_STATIC_ASSERT_V(rank(etensor) >= Int<2>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(PermutationMNK{}), + get<2>(PermutationMNK{})); + auto t_tensor = logical_divide(etensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto e_tile = make_tile(make_layout(size<0>(typename TiledMma::AtomShape_MNK{})), + make_layout(size<2>(typename TiledMma::AtomShape_MNK{}))); + auto e_tensor = zipped_divide(t_tensor, e_tile); // 
((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + using AtomLayoutE_TV = typename TiledMma::Atom::Traits::ELayout; + auto tv_tensor = e_tensor.compose(AtomLayoutE_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(mma.thr_layout_vmnk_)), + make_layout(size<3>(mma.thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE static constexpr + auto + get_layoutE_TV(TiledMMA const& mma) + { + // (M,K) -> (M,K) + auto ref_E = make_layout(make_shape(tile_size<0>(mma), tile_size<2>(mma))); + // (ethrid,val) -> (M,K) + auto layoutE_TV = thrfrg_E(mma, ref_E); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto etile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(mma.thr_layout_vmnk_), size<2>(mma.thr_layout_vmnk_)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(mma.thr_layout_vmnk_); + + // (thr_idx,val) -> (M,K) + return layoutE_TV.compose(etile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + partition_E(ThrMMA const& thr_mma, ETensor&& etensor) + { + auto thr_tensor = make_tensor(static_cast(etensor).data(), thrfrg_E(thr_mma, etensor.layout())); + + auto thr_vmk = make_coord(get<0>(thr_mma.thr_vmnk_), make_coord(get<1>(thr_mma.thr_vmnk_), get<3>(thr_mma.thr_vmnk_))); + return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_tiled_copy_E(Copy_Atom const& copy_atom, + TiledMMA const& mma) + { + return make_tiled_copy_impl(copy_atom, get_layoutE_TV(mma), make_shape(tile_size<0>(mma),tile_size<2>(mma))); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// 
+ +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d993d9a1f84635327ca24777ab9a49737973fd34 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -0,0 +1,774 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/builders/sm90_sparse_config.inl" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class LayoutPairAE_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedSparseFP8, + TileShape_, + ElementA_, + LayoutPairAE_, + 
ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedSparseFP8; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaRaw = typename ElementAMma::raw_type; + using LayoutPairAE = LayoutPairAE_; + using LayoutA = remove_cvref_t(LayoutPairAE{}))>; + using LayoutE = remove_cvref_t(LayoutPairAE{}))>; + using StrideA = decltype(cute::stride(LayoutA{})); + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + using ElementEMma = typename TiledMma::ValTypeE; + using ElementE = typename ElementEMma::raw_type; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + + static_assert(is_sparse::value, "ElementAMma is sparse"); + static_assert(!is_sparse::value, "ElementA is not sparse"); + + static constexpr int ElementAMmaSparsity = ElementAMma::sparsity; + static constexpr int ElementEMmaSparsity = ElementEMma::sparsity; + + // LayoutA is nested in the stride due to the sparsity. 
+ static constexpr bool is_A_mn_major = cute::is_same_v(LayoutA{}.stride())), Int>; + static constexpr bool is_B_mn_major = cutlass::gemm::detail::is_major<0,StrideB>(); + + using SparseConfig = cutlass::Sm90GemmSparseConfig(TileShape{}),_128{}))>; + + // The offline permutation for the metadata. + using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + + // Metadata pathways + using SmemCopyAtomE = AutoVectorizingCopy; + using GmemCopyAtomE = GmemTiledCopyA; + + using CtaShape_MNK = TileShape; + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M,K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (N,K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{})); + using SmemLayoutE = decltype(tile_to_shape( + SmemLayoutAtomE{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ using TmaInternalElementA = cute::sparse_elem, + cutlass::tfloat32_t, + uint_bit_t>>>; + using TmaInternalElementB = cute::conditional_t, + tfloat32_t, + uint_bit_t>>; + + struct SharedStorage + { + struct TensorStorage { + alignas(128) cute::ArrayEngine> smem_A; + alignas(128) cute::ArrayEngine> smem_B; + alignas(128) cute::ArrayEngine> smem_E; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 0; + + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutE{})) * cute::sizeof_bits_v); + + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutB{})) * cute::sizeof_bits_v); + + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{}; + LayoutA layout_a{}; + ElementB const* ptr_B{}; + StrideB dB{}; + ElementE const* ptr_E{}; + LayoutE layout_e{}; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), LayoutA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); // mcast along N mode for this M load, if any + + using TMA_E = decltype(make_tma_copy_A_sm90( // use uint64_t to get the largest loading box. 
+ GmemCopyAtomE{}, + make_tensor(recast_ptr(nullptr), LayoutE{}), + SmemLayoutE{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); // mcast along N mode for this M load, if any + + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); // mcast along M mode for this N load, if any + + TMA_A tma_load_a; + TMA_E tma_load_e; + TMA_B tma_load_b; + LayoutA layout_a; + LayoutE layout_e; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t mma_promotion_interval = 4; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_E = recast_ptr(args.ptr_E); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, args.layout_a); + Tensor tensor_e = make_tensor(ptr_E, args.layout_e); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); // mcast along N mode for this M load, if any + + typename Params::TMA_E tma_load_e = make_tma_copy_A_sm90( // use uint64_t to get the largest loading box. 
+ GmemCopyAtomE{}, + tensor_e, + SmemLayoutE{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); // mcast along N mode for this M load, if any + + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); // mcast along M mode for this N load, if any + + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_e, + tma_load_b, + args.layout_a, + args.layout_e, + transaction_bytes, + args.mma_promotion_interval + }; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool size_check = true; + // Check Alignment A + if constexpr (is_A_mn_major) { + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(M,K/2,L), cute::make_stride(_1{}, M, M*K/2)); + } + else { // If A is K-major + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(M,K/2,L), cute::make_stride(K/2, _1{}, M*K/2)); + } + size_check = size_check && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!size_check) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + // Check if layout_a and layout_e is filled correctly + auto layout_a_ref = SparseConfig::fill_layoutA(problem_shape_MNKL); + auto layout_e_ref = SparseConfig::fill_layoutE(problem_shape_MNKL); + bool 
layout_check = true; + layout_check = layout_check && (layout_a_ref == args.layout_a); + layout_check = layout_check && (layout_e_ref == args.layout_e); + + if (!layout_check) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Layout_a/e mismatch.\n"); + } + + /* MMA promotion interval should be a multiple of the number of MMA instructions issued by each mainloop iteration. */ + bool interval_check = args.mma_promotion_interval % (size<2>(TileShape{}) / TiledMma().template tile_size_mnk<2>()) == 0; + + if (!interval_check) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: MMA promotion interval is not a multiple of number of MMA instructions per tile.\n"); + } + + return size_check && layout_check && interval_check; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_e.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(mainloop_params.layout_a.shape()); // (m,k,l) + Tensor mE_mkl = mainloop_params.tma_load_e.get_tma_tensor(mainloop_params.layout_e.shape()); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gE_mkl = local_tile(mE_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, gE_mkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, class TensorE, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sE = make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), 
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + auto [gA_mkl, gB_nkl, gE_mkl] = load_inputs; + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + auto cta_coord_mnk = cta_layout_mnk.get_flat_coord(block_rank_in_cluster); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_e = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<0>(cta_layout_mnk, cta_coord_mnk); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(get<1>(cta_coord_mnk)); + auto block_tma_e = mainloop_params.tma_load_e.get_slice(get<1>(cta_coord_mnk)); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(get<0>(cta_coord_mnk)); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gE = gE_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tEgE = block_tma_e.partition_S(gE); // (TMA,TMA_M,TMA_K,k) + Tensor tEsE = block_tma_e.partition_D(sE); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = 
smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_e.with(*tma_barrier, mcast_mask_e), tEgE(_,_,_,*k_tile_iter), tEsE(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutE{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sE = as_position_independent_swizzle_tensor( + 
make_tensor(make_smem_ptr(shared_tensors.smem_E.begin()), SmemLayoutE{})); // (BLK_M,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + auto copy_atom_E = Copy_Atom{}; + + Tensor tCsE = partition_E(thread_mma, sE(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrE = make_fragment_like(tCsE); // (MMA,MMA_M,MMA_K) + + auto smem_tiled_copy_E = make_tiled_copy_E(copy_atom_E, tiled_mma); + auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(thread_idx); + + Tensor tEsE = smem_thr_copy_E.partition_S(sE); // (ECPY,ECPY_M,ECPY_K) + Tensor 
tErE = smem_thr_copy_E.retile_D(tCrE); // (ECPY,ECPY_M,ECPY_K) + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + int read_stage = smem_pipe_read.index(); + + // Load metadata smem->rmem for one stage + copy(smem_tiled_copy_E, tEsE(_,_,_,read_stage), tErE); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, make_zip_tensor(tCrA(_,_,k_block,read_stage), tErE(_,_,k_block)), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + accumulation.promote_if_needed(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + 
pipeline.consumer_wait(smem_pipe_read, barrier_token); + int read_stage = smem_pipe_read.index(); + + // Load metadata smem->rmem for one stage + copy(smem_tiled_copy_E, tEsE(_,_,_,read_stage), tErE); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, make_zip_tensor(tCrA(_,_,k_block,read_stage), tErE(_,_,k_block)), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation()); + + accumulation.promote_if_needed(); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation.promote_residue_if_needed(); + + warpgroup_fence_operand(accumulation()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + +private: + + template + CUTE_HOST_DEVICE static constexpr + auto + thrfrg_E(TiledMMA const& mma, ETensor&& etensor) + { + using TiledMma = TiledMMA; + + 
CUTE_STATIC_ASSERT_V(rank(etensor) >= Int<2>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(PermutationMNK{}), + get<2>(PermutationMNK{})); + auto t_tensor = logical_divide(etensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto e_tile = make_tile(make_layout(size<0>(typename TiledMma::AtomShape_MNK{})), + make_layout(size<2>(typename TiledMma::AtomShape_MNK{}))); + auto e_tensor = zipped_divide(t_tensor, e_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + using AtomLayoutE_TV = typename TiledMma::Atom::Traits::ELayout; + auto tv_tensor = e_tensor.compose(AtomLayoutE_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(mma.thr_layout_vmnk_)), + make_layout(size<3>(mma.thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE static constexpr + auto + get_layoutE_TV(TiledMMA const& mma) + { + // (M,K) -> (M,K) + auto ref_E = make_layout(make_shape(tile_size<0>(mma), tile_size<2>(mma))); + // (ethrid,val) -> (M,K) + auto layoutE_TV = thrfrg_E(mma, ref_E); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto etile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(mma.thr_layout_vmnk_), size<2>(mma.thr_layout_vmnk_)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(mma.thr_layout_vmnk_); + + // (thr_idx,val) -> (M,K) + return layoutE_TV.compose(etile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + partition_E(ThrMMA const& thr_mma, ETensor&& etensor) + { + auto thr_tensor = make_tensor(static_cast(etensor).data(), thrfrg_E(thr_mma, etensor.layout())); + + auto thr_vmk = make_coord(get<0>(thr_mma.thr_vmnk_), 
make_coord(get<1>(thr_mma.thr_vmnk_), get<3>(thr_mma.thr_vmnk_))); + return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_tiled_copy_E(Copy_Atom const& copy_atom, + TiledMMA const& mma) + { + return make_tiled_copy_impl(copy_atom, get_layoutE_TV(mma), make_shape(tile_size<0>(mma),tile_size<2>(mma))); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/base_grouped.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/base_grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..d9c2423b2bfe384695d83cad1737e2bbfc1e0f62 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/base_grouped.h @@ -0,0 +1,478 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Base device-level grouped kernel. 
+*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class BaseGrouped { +public: + + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using 
MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + +protected: + + /// Kernel parameters object + typename BaseKernel::Params params_; + +private: + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes, cudaStream_t stream = nullptr) { + cudaError_t cuda_error = cudaMemcpyAsync(workspace, data, bytes, cudaMemcpyHostToDevice, stream); + if (cuda_error != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaMemcpy() returned error " + << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const &args, int32_t tile_count, void* workspace, cudaStream_t stream = nullptr) { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes, + args.problem_count, + 
args.threadblock_count, + (void*)host_workspace.data()); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes, stream); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, const std::vector& indices) { + // For now, simply create a copy of the data and then copy over to the original. + std::vector copy(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + copy.at(i) = data[indices[i]]; + } + + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + +public: + + /// Constructs the GEMM. + BaseGrouped() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return BaseKernel::can_implement(args); + } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const &problem) { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const &args) { + if (args.host_problem_sizes == nullptr) { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; + } + + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + return BaseKernel::ProblemVisitor::get_workspace_size(args.host_problem_sizes, + args.problem_count, + args.threadblock_count); + } else { + return 0; + } + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + + return dim3(args.threadblock_count, 1, 1); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + + CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); + + 
int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error " + << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + Kernel, + BaseKernel::kThreadCount, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. 
+ static void sort_problems(int problem_count, + cutlass::gemm::GemmCoord* problem_sizes_ptr, + int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, + int64_t* ldc_host_ptr, + int64_t* ldd_host_ptr, + int64_t* offset_A_ptr, + int64_t* offset_B_ptr, + int64_t* offset_C_ptr, + int64_t* offset_D_ptr) + { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), indices.end(), + [&problem_sizes_ptr](size_t i, size_t j) { + return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); + }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient(const cutlass::gemm::GemmCoord* problem_sizes_ptr=nullptr, + int problem_count=0, + int available_sm_count=-1) { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. 
+ int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " + << cudaGetErrorString(result)); + return 0; + } + + int multiprocessor_count; + result = cudaDeviceGetAttribute(&multiprocessor_count, + cudaDevAttrMultiProcessorCount, device_idx); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaDeviceGetAttribute() returned error " + << cudaGetErrorString(result)); + return 0; + } + + bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); + if (override_sm_count) { + available_sm_count = multiprocessor_count; + } + + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) { + return 0; + } + + int occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) { + return occupancy_based_block_count; + } + + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); + + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return total_tiles + // unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) { + return total_tiles; + } + + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating through + // problem sizes to determine that they have no work to do. This competes for cycles + // with those threadblocks that are assigned tiles to compute. 
+ return std::min(total_tiles, occupancy_based_block_count); + } + + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + CUTLASS_TRACE_HOST("BaseGrouped::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + params_ = typename BaseKernel::Params(args, workspace, tile_count); + } else { + params_ = typename BaseKernel::Params(args, workspace); + } + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + params_.update(args, workspace, tile_count); + } else { + params_.update(args, workspace); + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + // + // Configure grid and block dimensions + // + + if (!params_.problem_visitor.problem_count) { + return Status::kSuccess; + } + + dim3 grid(params_.threadblock_count, 1, 1); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + // + // Launch kernel + // + + // Launch + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Initializes and runs the kernel. + Status operator()( + Arguments const &args, + void *workspace, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h new file mode 100644 index 0000000000000000000000000000000000000000..75edf2fc2c92ad344f4e30790b80e8185e0744db --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/default_gemm_configuration.h @@ -0,0 +1,955 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Definitions for GEMM structures +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename OperatorClass, + typename ArchTag, + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementAccumulator +> +struct DefaultGemmConfiguration; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ArchTag, + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementAccumulator> +struct DefaultGemmConfiguration< + arch::OpClassSimt, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator> { + + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + using ThreadblockShape = GemmShape<128, 128, 8>; + using WarpShape = GemmShape<32, 64, 8>; + using InstructionShape = GemmShape<1, 1, 1>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >; + + using Operator = arch::OpMultiplyAdd; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ArchTag, + typename ElementC> +struct DefaultGemmConfiguration { + + static int const kAlignmentA = 4; + static int const kAlignmentB = 4; + using ThreadblockShape = GemmShape<128, 128, 32>; + using WarpShape = GemmShape<32, 64, 32>; + using InstructionShape = GemmShape<1, 1, 4>; + static int const kStages = 2; + + using 
EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, + 1, + int32_t, + float + >; + + using Operator = arch::OpMultiplyAdd; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ArchTag, + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementAccumulator> +struct DefaultGemmConfiguration< + arch::OpClassWmmaTensorOp, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, + 128 / sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Operator = arch::OpMultiplyAdd; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementAccumulator> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm70, + ElementA, + ElementB, + ElementC, + ElementAccumulator> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 32>; + using WarpShape = GemmShape<64, 64, 32>; + using InstructionShape = GemmShape<8, 8, 4>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, + 128 / sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Operator = arch::OpMultiplyAdd; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementAccumulator> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + ElementA, + ElementB, + ElementC, + ElementAccumulator> { + + 
static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + using ThreadblockShape = GemmShape<128, 256, 32>; + using WarpShape = GemmShape<64, 64, 32>; + using InstructionShape = GemmShape<16, 8, 8>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, + 128 / sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Operator = typename platform::conditional< + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd>::type; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + int8_t, + int8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<8, 8, 16>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + int8_t, + uint8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<8, 8, 16>; + static int const kStages = 2; + + using EpilogueOutputOp = 
epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + uint8_t, + int8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<8, 8, 16>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + uint8_t, + uint8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<8, 8, 16>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + int4b_t, + int4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + 
using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<8, 8, 32>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + int4b_t, + uint4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<8, 8, 32>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + uint4b_t, + int4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<8, 8, 32>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + uint4b_t, + uint4b_t, + ElementC, + int32_t> { + + static int 
const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<8, 8, 32>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm75, + uint1b_t, + uint1b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 512>; + using WarpShape = GemmShape<64, 64, 512>; + using InstructionShape = GemmShape<8, 8, 128>; + static int const kStages = 2; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpXorPopc; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultGemmConfiguration { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 16>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, 128 / sizeof_bits::value, ElementAccumulator, + ElementAccumulator>; + + using Operator = typename platform::conditional< + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd>::type; +}; + 
+//////////////////////////////////////////////////////////////////////////////// +template +struct DefaultGemmConfiguration { + + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + + using ThreadblockShape = GemmShape<128, 128, 16>; + using WarpShape = GemmShape<32, 64, 16>; + using InstructionShape = GemmShape<8, 8, 4>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, 1, ElementAccumulator, + ElementAccumulator>; + + using Operator = arch::OpMultiplyAdd; +}; + + +template <> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + complex, + complex, + complex, + complex + > { + + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + + using ThreadblockShape = GemmShape<64, 64, 16>; + using WarpShape = GemmShape<32, 32, 16>; + using InstructionShape = GemmShape<8, 8, 4>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + complex, 1, complex, + complex>; + + using Operator = arch::OpMultiplyAddComplex; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int8_t, + int8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + 
int8_t, + uint8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + uint8_t, + int8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + uint8_t, + uint8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + 
+//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int4b_t, + int4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<16, 8, 64>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int4b_t, + uint4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<16, 8, 64>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + uint4b_t, + int4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<16, 8, 64>; + static int const kStages = 3; + + using 
EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + uint4b_t, + uint4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 128>; + using WarpShape = GemmShape<64, 64, 128>; + using InstructionShape = GemmShape<16, 8, 64>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + uint1b_t, + uint1b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 512>; + using WarpShape = GemmShape<64, 64, 512>; + using InstructionShape = GemmShape<16, 8, 256>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAdd; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int4b_t, + int8_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = 
GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementC> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm80, + int8_t, + int4b_t, + ElementC, + int32_t> { + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp< + ElementC, 128 / sizeof_bits::value, int32_t, float>; + + using Operator = arch::OpMultiplyAddSaturate; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Base configuration for all {fe4m3, fe5m2} x {fe4m3, fe5m2} combinations on SM89 +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementAccumulator> +struct DefaultGemmConfigurationSm89F8 { + static_assert((platform::is_same::value || + platform::is_same::value), + "ElementA must be of type float_e4m3_t or float_e5m2_t"); + static_assert((platform::is_same::value || + platform::is_same::value), + "ElementB must be of type float_e4m3_t or float_e5m2_t"); + + static int const kAlignmentA = 128 / sizeof_bits::value; + static int const kAlignmentB = 128 / sizeof_bits::value; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 32>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + 
ElementC, 128 / sizeof_bits::value, ElementAccumulator, + ElementAccumulator>; + + using Operator = arch::OpMultiplyAdd; +}; + +/// Partial specialization for SM89 fe4m3 x fe4m3 +template +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm89, + cutlass::float_e4m3_t, + cutlass::float_e4m3_t, + ElementC, + ElementAccumulator> : DefaultGemmConfigurationSm89F8< + cutlass::float_e4m3_t, + cutlass::float_e4m3_t, + ElementC, + ElementAccumulator> {}; + +/// Partial specialization for SM89 fe4m3 x fe5m2 +template +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm89, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t, + ElementC, + ElementAccumulator> : DefaultGemmConfigurationSm89F8< + cutlass::float_e4m3_t, + cutlass::float_e5m2_t, + ElementC, + ElementAccumulator> {}; + +/// Partial specialization for SM89 fe5m2 x fe4m3 +template +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm89, + cutlass::float_e5m2_t, + cutlass::float_e4m3_t, + ElementC, + ElementAccumulator> : DefaultGemmConfigurationSm89F8< + cutlass::float_e5m2_t, + cutlass::float_e4m3_t, + ElementC, + ElementAccumulator> {}; + +/// Partial specialization for SM89 fe5m2 x fe5m2 +template +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm89, + cutlass::float_e5m2_t, + cutlass::float_e5m2_t, + ElementC, + ElementAccumulator> : DefaultGemmConfigurationSm89F8< + cutlass::float_e5m2_t, + cutlass::float_e5m2_t, + ElementC, + ElementAccumulator> {}; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultGemmConfiguration { + + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 4>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, 1, ElementAccumulator, + ElementAccumulator>; 
+ + using Operator = arch::OpMultiplyAdd; +}; + +template <> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm90, + complex, + complex, + complex, + complex + > { + + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + + using ThreadblockShape = GemmShape<64, 64, 16>; + using WarpShape = GemmShape<32, 32, 16>; + using InstructionShape = GemmShape<16, 8, 4>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + complex, 1, complex, + complex>; + + using Operator = arch::OpMultiplyAddComplex; +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/ell_gemm.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/ell_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..097debf5bed5e356881f8ef7e8515d726645f8d6 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/ell_gemm.h @@ -0,0 +1,849 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a Block-Ell sparse gemm kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/ell_gemm.h" + +#include "cutlass/gemm/kernel/default_ell_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! Blocked-Ell sparse gemm device-level operator. This is an interface to efficient CUTLASS + Blocked-Ell kernels that may be invoked from host code. + + The contributions of this class are: + + 1. 
At compile time, it maps data types and high-level structural parameters onto + specific CUTLASS components. + + 2. At runtime, it maps logical arguments to Blocked-Ell problems to kernel parameters. + + 3. At runtime, it launches kernels on the device. + + Example of a CUTLASS EllGemm operator is as follows: + + // + // Instantiate the CUTLASS EllGemm operator. + // + + cutlass::gemm::device::EllGemm< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::half_t, + cutlass::layout::ColumnMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + cutlass::half_t, 128 / cutlass::sizeof_bits::value, + float, float>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4, // Stages + 128 / cutlass::sizeof_bits::value, // Alignment A + 128 / cutlass::sizeof_bits::value // Alignment B + > ellgemm_op; + + // + // Launch the EllGemm operation on the device + // + + Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format: + a_rows - Rows in the sparse matrix. + a_cols - Columns in the sparse matrix. + BlockedEllA - Packed matrix (ellValue matrix) that stores non-zero values in + consecutive blocks, whose size is (a_rows * a_ell_num_columns) + ell_idx - Blocked-ELL Column indices (ellColInd) matrix, whose size is + (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize) + a_ell_blocksize - Size of the ELL-Blocks. 
+ a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns) + B - Input dense matrix whose size is (a_cols * n) + C/D - Output dense matrix whose size is (a_rows * n) + + cutlass::Status status = ellgemm_op({ + {a_rows, n, a_cols}, // GemmCoord problem_size + {BlockedEllA, lda}, // TensorRef ref_BlockedEllA + {B, ldb}, // TensorRef ref_B, + {C, ldc}, // TensorRef ref_C, + {D, ldd}, // TensorRef ref_D, + ell_idx, // Blocked-ELL Column indices or ellColInd matrix (const int*) + a_ell_num_columns, // Columns in the Blocked-Ellpack (ellValue) matrix (int) + a_ell_blocksize, // Size of the ELL-Blocks (int) + a_ell_base, // Base index of ellColInd (int) - Zero or One + {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params + }); + + A simplified view of the template is listed below. + + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + + /// Access granularity of A matrix in units of elements + int AlignmentA, + + /// Access granularity of B matrix in units of elements + int AlignmentB, + + /// Supports split-K with serial reduction + bool SplitKSerial, + + /// Operation performed by GEMM + typename Operator, + + /// Sparse matrix is A or not + bool IsASparse + > + class EllGemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename 
DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Sparse matrix is A or not + bool IsASparse = true + > +class EllGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC 
= EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static bool const kIsASparse = IsASparse; + + /// Define the kernel + using GemmKernel = typename kernel::DefaultEllGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + kIsASparse + >::GemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + const int* ell_idx; + int ell_ncol; + int ell_blocksize; + int ell_base_idx; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + const int* ell_idx_, + int ell_ncol_, + int ell_blocksize_, + int ell_base_idx_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1 + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + ell_idx(ell_idx_), + ell_ncol(ell_ncol_), + ell_blocksize(ell_blocksize_), + ell_base_idx(ell_base_idx_), + epilogue(epilogue_), + split_k_slices(split_k_slices) { + + } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_{}; + +public: + + /// Constructs the GEMM. + EllGemm() { } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {args.ell_blocksize, + ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + tiled_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1 ) / ThreadblockShape::kM; + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace){ + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.ell_idx, + args.ell_ncol, + args.ell_blocksize, + args.ell_base_idx, + args.epilogue, + static_cast(workspace) + }; + return Status::kSuccess; + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {args.ell_blocksize, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + grid_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1 ) / ThreadblockShape::kM; + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + return set(args, grid_shape, workspace); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// If true, kernel supports split-K as a serial reduction + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator_, + /// Sparse matrix is A or not + bool IsASparse> +class EllGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using 
Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static bool const kSplitKSerial = SplitKSerial; + static bool const kIsASparse = false; + + using UnderlyingOperator = EllGemm< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + SplitKSerial, + Operator, + kIsASparse + >; + + using UnderlyingArguments = typename UnderlyingOperator::Arguments; + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = UnderlyingOperator::kAlignmentC; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + const int* ell_idx; + int ell_ncol; + int ell_blocksize; + int ell_base_idx; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + const int* ell_idx_, + int ell_ncol_, + int ell_blocksize_, + int ell_base_idx_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1 + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + ell_idx(ell_idx_), + ell_ncol(ell_ncol_), + ell_blocksize(ell_blocksize_), + ell_base_idx(ell_base_idx_), + 
epilogue(epilogue_), + split_k_slices(split_k_slices) { } + }; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + EllGemm() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static UnderlyingArguments to_underlying_arguments(Arguments const &args) { + return UnderlyingArguments( + {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, + {args.ref_B.data(), args.ref_B.stride(0)}, + {args.ref_A.data(), args.ref_A.stride(0)}, + {args.ref_C.data(), args.ref_C.stride(0)}, + {args.ref_D.data(), args.ref_D.stride(0)}, + args.ell_idx, + args.ell_ncol, + args.ell_blocksize, + args.ell_base_idx, + args.epilogue, + args.split_k_slices + ); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK}, + args.split_k_slices); + + tiled_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1 ) / ThreadblockShape::kN; + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace){ + // Initialize the Params structure + return underlying_operator_.set(to_underlying_arguments(args), grid_shape, workspace); + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, + {ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK}, + args.split_k_slices); + + grid_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1 ) / ThreadblockShape::kN; + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + set(args, grid_shape, workspace); + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..f4ea4ebe86bedabc28b3ea667dcd8f735b667868 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm.h @@ -0,0 +1,772 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm.h" + +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may + be invoked from host code. + + The contributions of this class are: + + 1. At compile time, it maps data types and high-level structural parameters onto + specific CUTLASS components. + + 2. 
At runtime, it maps logical arguments to GEMM problems to kernel parameters. + + 3. At runtime, it launches kernels on the device. + + The intent is to provide a convenient mechanism for interacting with most plausible GEMM + configurations for each supported architecture. Consequently, not all parameters are exposed + to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy + are selected to tradeoff simplicity of the interface with flexibility. We expect + most configurations to be specified at this level. Applications with more exotic requirements + may construct their kernels of interest using CUTLASS components at the threadblock, warp, + and thread levels of abstraction. + + CUTLASS exposes computations using the functor design pattern in which objects compose some + internal state with an overloaded function call operator. This enables decoupling of + initialization from execution, possibly reducing overhead during steady state phases of + application execution. + + CUTLASS device-level operators expose an Arguments structure encompassing each logical + input to the computation. This is distinct from the kernel-level Params structure pattern + which contains application-specific precomputed state needed by the device code. + + Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN + is as follows: + + // + // Instantiate the CUTLASS GEMM operator. 
+ // + + cutlass::gemm::device::Gemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor + > gemm_op; + + // + // Launch the GEMM operation on the device + // + + cutlass::Status status = gemm_op({ + {m, n, k}, // GemmCoord problem_size, + {A, lda}, // TensorRef ref_A, + {B, ldb}, // TensorRef ref_B, + {C, ldc}, // TensorRef ref_C, + {D, ldd}, // TensorRef ref_D, + {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params + }); + + + A simplified view of the template is listed below. + + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + > + class Gemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level 
swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute> +class Gemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = 
ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Define the kernel + using GemmKernel = typename kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + SharedMemoryClearOption::kNone, + GatherA, + GatherB, + ScatterD, + PermuteDLayout + >::GemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1, + int const *gather_A_indices_ = nullptr, + int const *gather_B_indices_ = nullptr, + int const *scatter_D_indices_ = nullptr + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_), + split_k_slices(split_k_slices), + gather_A_indices(gather_A_indices_), + gather_B_indices(gather_B_indices_), + scatter_D_indices(scatter_D_indices_) { + + } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + Gemm() { } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + 
args.epilogue, + static_cast(workspace), + args.gather_A_indices, + args.gather_B_indices, + args.scatter_D_indices + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// If true, kernel supports split-K as a serial reduction + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator_, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout +> +class Gemm { + public: + + using ElementA = ElementA_; + using 
LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static bool const kSplitKSerial = SplitKSerial; + + using UnderlyingOperator = Gemm< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + SplitKSerial, + Operator, + GatherB, + GatherA, + ScatterD, + PermuteDLayout + >; + + using UnderlyingArguments = typename UnderlyingOperator::Arguments; + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = UnderlyingOperator::kAlignmentC; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + // For gather+scatter operations + int *gather_A_indices; + int *gather_B_indices; + int *scatter_D_indices; + + // + // Methods + // + + 
/// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1, + int *gather_A_indices_ = nullptr, + int *gather_B_indices_ = nullptr, + int *scatter_D_indices_ = nullptr + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_), + split_k_slices(split_k_slices), + gather_A_indices(gather_A_indices_), + gather_B_indices(gather_B_indices_), + scatter_D_indices(scatter_D_indices_) { } + }; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + Gemm() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static UnderlyingArguments to_underlying_arguments(Arguments const &args) { + return UnderlyingArguments( + {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, + {args.ref_B.data(), args.ref_B.stride(0)}, + {args.ref_A.data(), args.ref_A.stride(0)}, + {args.ref_C.data(), args.ref_C.stride(0)}, + {args.ref_D.data(), args.ref_D.stride(0)}, + args.epilogue, + args.split_k_slices, + args.gather_B_indices, + args.gather_A_indices, + args.scatter_D_indices + ); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_array.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_array.h new file mode 100644 index 0000000000000000000000000000000000000000..ab5ed26b0d5008d9164661a2b1763f86540b41c5 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_array.h @@ -0,0 +1,738 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_array.h" + +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may + be invoked from host code. + + The contributions of this class are: + + 1. At compile time, it maps data types and high-level structural parameters onto + specific CUTLASS components. + + 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. + + 3. At runtime, it launches kernels on the device. + + The intent is to provide a convenient mechanism for interacting with most plausible GEMM + configurations for each supported architecture. Consequently, not all parameters are exposed + to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy + are selected to tradeoff simplicity of the interface with flexibility. We expect + most configurations to be specified at this level. Applications with more exotic requirements + may construct their kernels of interest using CUTLASS components at the threadblock, warp, + and thread levels of abstraction. + + CUTLASS exposes computations using the functor design pattern in which objects compose some + internal state with an overloaded function call operator. This enables decoupling of + initialization from execution, possibly reducing overhead during steady state phases of + application execution. 
+ + CUTLASS device-level operators expose an Arguments structure encompassing each logical + input to the computation. This is distinct from the kernel-level Params structure pattern + which contains application-specific precomputed state needed by the device code. + + Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN + is as follows: + + // + // Instantiate the CUTLASS GEMM operator. + // + + cutlass::gemm::device::Gemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor + > gemm_op; + + // + // Launch the GEMM operation on the device + // + + cutlass::Status status = gemm_op({ + {m, n, k}, // GemmCoord problem_size, + {A, lda}, // TensorRef ref_A, + {B, ldb}, // TensorRef ref_B, + {C, ldc}, // TensorRef ref_C, + {D, ldd}, // TensorRef ref_D, + {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params + }); + + + A simplified view of the template is listed below. + + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + > + class Gemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level 
swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator +> +class GemmArray { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + using Operator = Operator_; + + /// Define the kernel + using DefaultGemmKernel = typename kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + false, + Operator + >::GemmKernel; + + using GemmKernel = kernel::GemmArray; + + /// Argument structure + struct Arguments { + + // + // 
Data members + // + + GemmCoord problem_size; + + ElementA const * const *ptr_A; + LayoutA layout_A; + + ElementB const * const *ptr_B; + LayoutB layout_B; + + ElementC const * const *ptr_C; + LayoutC layout_C; + + ElementC * const * ptr_D; + LayoutC layout_D; + + typename EpilogueOutputOp::Params epilogue; + int batch_count; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + ElementA const * const *ptr_A_, + LayoutA layout_A_, + ElementB const * const *ptr_B_, + LayoutB layout_B_, + ElementC const * const *ptr_C_, + LayoutC layout_C_, + ElementC * const * ptr_D_, + LayoutC layout_D_, + typename EpilogueOutputOp::Params epilogue_, + int batch_count_ + ): + problem_size(problem_size_), + ptr_A(ptr_A_), + layout_A(layout_A_), + ptr_B(ptr_B_), + layout_B(layout_B_), + ptr_C(ptr_C_), + layout_C(layout_C_), + ptr_D(ptr_D_), + layout_D(layout_D_), + epilogue(epilogue_), + batch_count(batch_count_) { } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + GemmArray() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (args.layout_A.stride(0) % kAlignmentA) { + return Status::kErrorMisalignedOperand; + } + + if (args.layout_B.stride(0) % kAlignmentB) { + return Status::kErrorMisalignedOperand; + } + + if (args.layout_C.stride(0) % kAlignmentC) { + return Status::kErrorMisalignedOperand; + } + + if (args.layout_D.stride(0) % kAlignmentC) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + return 0; + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ptr_A, + args.layout_A, + args.ptr_B, + args.layout_B, + args.ptr_C, + args.layout_C, + args.ptr_D, + args.layout_D, + args.epilogue, + args.batch_count + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + args.batch_count, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}); + + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ptr_A, + args.layout_A, + args.ptr_B, + args.layout_B, + args.ptr_C, + args.layout_C, + args.ptr_D, + args.layout_D, + args.epilogue, + args.batch_count + }; + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + typename Operator_ +> +class GemmArray< + ElementA_, + LayoutA_, + ElementB_, + LayoutB_, + ElementC_, + layout::ColumnMajor, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + AlignmentA, + AlignmentB, + Operator_ +> { +public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = 
InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static int const kStages = Stages; + + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = false; + + // + using UnderlyingOperator = GemmArray< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA + >; + + using UnderlyingArguments = typename UnderlyingOperator::Arguments; + using GemmKernel = typename UnderlyingOperator::GemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + + ElementA const * const *ptr_A; + LayoutA layout_A; + + ElementB const * const *ptr_B; + LayoutB layout_B; + + ElementC const * const *ptr_C; + LayoutC layout_C; + + ElementC * const * ptr_D; + LayoutC layout_D; + + typename EpilogueOutputOp::Params epilogue; + int batch_count; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + ElementA const * const *ptr_A_, + LayoutA layout_A_, + ElementB const * const *ptr_B_, + LayoutB layout_B_, + ElementC const * const *ptr_C_, + LayoutC layout_C_, + ElementC * const * ptr_D_, + LayoutC layout_D_, + typename EpilogueOutputOp::Params epilogue_, + int batch_count_ + ): + problem_size(problem_size_), + ptr_A(ptr_A_), + layout_A(layout_A_), + ptr_B(ptr_B_), + layout_B(layout_B_), + ptr_C(ptr_C_), + layout_C(layout_C_), + ptr_D(ptr_D_), + layout_D(layout_D_), + epilogue(epilogue_), + batch_count(batch_count_) { } + }; + +private: + + 
UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmArray() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static UnderlyingArguments to_underlying_arguments(Arguments const &args) { + + GemmCoord problem_size{ + args.problem_size.n(), + args.problem_size.m(), + args.problem_size.k() + }; + + return UnderlyingArguments( + problem_size, + args.ptr_B, + args.layout_B.stride(), + args.ptr_A, + args.layout_A.stride(), + args.ptr_C, + args.layout_C.stride(), + args.ptr_D, + args.layout_D.stride(), + args.epilogue, + args.batch_count + ); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_batched.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_batched.h new file mode 100644 index 0000000000000000000000000000000000000000..4a5b4105b3ad23151c534f0bd42884a33fe296a3 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_batched.h @@ -0,0 +1,704 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined batch GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_batched.h" + +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may + be invoked from host code. + + The contributions of this class are: + + 1. At compile time, it maps data types and high-level structural parameters onto + specific CUTLASS components. + + 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. + + 3. At runtime, it launches kernels on the device. 
+ + The intent is to provide a convenient mechanism for interacting with most plausible GEMM + configurations for each supported architecture. Consequently, not all parameters are exposed + to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy + are selected to tradeoff simplicity of the interface with flexibility. We expect + most configurations to be specified at this level. Applications with more exotic requirements + may construct their kernels of interest using CUTLASS components at the threadblock, warp, + and thread levels of abstraction. + + CUTLASS exposes computations using the functor design pattern in which objects compose some + internal state with an overloaded function call operator. This enables decoupling of + initialization from execution, possibly reducing overhead during steady state phases of + application execution. + + CUTLASS device-level operators expose an Arguments structure encompassing each logical + input to the computation. This is distinct from the kernel-level Params structure pattern + which contains application-specific precomputed state needed by the device code. + + Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN + is as follows: + + // + // Instantiate the CUTLASS GEMM operator. + // + + cutlass::gemm::device::Gemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor + > gemm_op; + + // + // Launch the GEMM operation on the device + // + + cutlass::Status status = gemm_op({ + {m, n, k}, // GemmCoord problem_size, + {A, lda}, // TensorRef ref_A, + {B, ldb}, // TensorRef ref_B, + {C, ldc}, // TensorRef ref_C, + {D, ldd}, // TensorRef ref_D, + {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params + }); + + + A simplified view of the template is listed below. 
+ + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + > + class Gemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, 
ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator +> +class GemmBatched { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using 
ThreadblockSwizzle = ThreadblockSwizzle_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + using Operator = Operator_; + + /// Define the kernel + using DefaultGemmKernel = typename kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + false, + Operator + >::GemmKernel; + + using GemmKernel = kernel::GemmBatched; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + int64_t stride_A; + TensorRef ref_B; + int64_t stride_B; + TensorRef ref_C; + int64_t stride_C; + TensorRef ref_D; + int64_t stride_D; + typename EpilogueOutputOp::Params epilogue; + int batch_count; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + int64_t stride_A_, + TensorRef ref_B_, + int64_t stride_B_, + TensorRef ref_C_, + int64_t stride_C_, + TensorRef ref_D_, + int64_t stride_D_, + typename EpilogueOutputOp::Params epilogue_, + int batch_count_ + ): + problem_size(problem_size_), + ref_A(ref_A_), + stride_A(stride_A_), + ref_B(ref_B_), + stride_B(stride_B_), + ref_C(ref_C_), + stride_C(stride_C_), + ref_D(ref_D_), + stride_D(stride_D_), + epilogue(epilogue_), + batch_count(batch_count_) { } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + GemmBatched() { } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (!TensorRef_aligned(args.ref_A, kAlignmentA) || (args.stride_A % kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, kAlignmentB) || (args.stride_B % kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, kAlignmentC) || (args.stride_C % kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, kAlignmentC) || (args.stride_D % kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + return 0; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.stride_A, + args.ref_B.non_const_ref(), + args.stride_B, + args.ref_C.non_const_ref(), + args.stride_C, + args.ref_D, + args.stride_D, + args.epilogue, + args.batch_count + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + typename Operator_ +> +class GemmBatched< + ElementA_, + LayoutA_, + ElementB_, + LayoutB_, + ElementC_, + layout::ColumnMajor, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + AlignmentA, + AlignmentB, + Operator_ +> { +public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = 
InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static int const kStages = Stages; + + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = false; + + // + using UnderlyingOperator = GemmBatched< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA + >; + + using UnderlyingArguments = typename UnderlyingOperator::Arguments; + using GemmKernel = typename UnderlyingOperator::GemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + int64_t stride_A; + TensorRef ref_B; + int64_t stride_B; + TensorRef ref_C; + int64_t stride_C; + TensorRef ref_D; + int64_t stride_D; + typename EpilogueOutputOp::Params epilogue; + int batch_count; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + int64_t stride_A_, + TensorRef ref_B_, + int64_t stride_B_, + TensorRef ref_C_, + int64_t stride_C_, + TensorRef ref_D_, + int64_t stride_D_, + typename EpilogueOutputOp::Params epilogue_, + int batch_count_ + ): + problem_size(problem_size_), + ref_A(ref_A_), + stride_A(stride_A_), + ref_B(ref_B_), + stride_B(stride_B_), + ref_C(ref_C_), + stride_C(stride_C_), + ref_D(ref_D_), + stride_D(stride_D_), + epilogue(epilogue_), + batch_count(batch_count_) { } + }; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. 
+ GemmBatched() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static UnderlyingArguments to_underlying_arguments(Arguments const &args) { + return UnderlyingArguments( + {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, + {args.ref_B.data(), args.ref_B.stride(0)}, + args.stride_B, + {args.ref_A.data(), args.ref_A.stride(0)}, + args.stride_A, + {args.ref_C.data(), args.ref_C.stride(0)}, + args.stride_C, + {args.ref_D.data(), args.ref_D.stride(0)}, + args.stride_D, + args.epilogue, + args.batch_count + ); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_complex.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..b0403230af18a8c12983a8e9d8b71d840d4f84f7 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_complex.h @@ -0,0 +1,718 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm.h" + +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM + kernels that may be invoked from host code. + + The contributions of this class are: + + 1. At compile time, it maps data types and high-level structural parameters + onto specific CUTLASS components. + + 2. 
At runtime, it maps logical arguments to GEMM problems to kernel + parameters. + + 3. At runtime, it launches kernels on the device. + + The intent is to provide a convenient mechanism for interacting with most + plausible GEMM configurations for each supported architecture. Consequently, + not all parameters are exposed to the top-level interface. Rather, sensible + defaults at each level of the CUTLASS hierarchy are selected to tradeoff + simplicity of the interface with flexibility. We expect most configurations to + be specified at this level. Applications with more exotic requirements may + construct their kernels of interest using CUTLASS components at the + threadblock, warp, and thread levels of abstraction. + + CUTLASS exposes computations using the functor design pattern in which objects + compose some internal state with an overloaded function call operator. This + enables decoupling of initialization from execution, possibly reducing + overhead during steady state phases of application execution. + + CUTLASS device-level operators expose an Arguments structure encompassing each + logical input to the computation. This is distinct from the kernel-level + Params structure pattern which contains application-specific precomputed state + needed by the device code. + + Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's + SGEMM NN is as follows: + + // + // Instantiate the CUTLASS GEMM operator. 
+ // + + cutlass::gemm::device::Gemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor + > gemm_op; + + // + // Launch the GEMM operation on the device + // + + cutlass::Status status = gemm_op({ + {m, n, k}, // GemmCoord problem_size, + {A, lda}, // TensorRef ref_A, + {B, ldb}, // TensorRef ref_B, + {C, ldc}, // TensorRef ref_C, + {D, ldd}, // TensorRef ref_D, + {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params + }); + + + A simplified view of the template is listed below. + + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + > + class Gemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for. 
+ typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone, + /// Multiply-add operator + // (selects complex or gaussian complex) + typename Operator_ = arch::OpMultiplyAddComplex, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false> +class GemmComplex { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = 
ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static int const kStages = Stages; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + using Operator = Operator_; + static bool const kSplitKSerial = SplitKSerial; + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Define the kernel + using GemmKernel = typename kernel::DefaultGemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kTransformA, + kTransformB, + Operator, + kSplitKSerial + >::GemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1 + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_), + split_k_slices(split_k_slices) { + + } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. 
+ GemmComplex() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + if (kSplitKSerial && args.split_k_slices > 1) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return 0; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.epilogue, + static_cast(workspace) + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void 
*workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Multiply-add operator + // (selects complex or gaussian complex) + typename Operator_, + /// If true, kernel supports split-K as a serial reduction + bool SplitKSerial +> +class GemmComplex< + ElementA_, + LayoutA_, + ElementB_, + LayoutB_, + ElementC_, + layout::ColumnMajor, // partially specialized on LayoutC + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + TransformA, + TransformB, + Operator_, + SplitKSerial +> { +public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + 
using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static int const kStages = Stages; + using Operator = Operator_; + static bool const kSplitKSerial = SplitKSerial; + + using UnderlyingOperator = GemmComplex< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + TransformB, + TransformA, + Operator, + SplitKSerial + >; + + static int const kAlignmentA = UnderlyingOperator::kAlignmentB; + static int const kAlignmentB = UnderlyingOperator::kAlignmentA; + static int const kAlignmentC = UnderlyingOperator::kAlignmentC; + static ComplexTransform const kTransformA = UnderlyingOperator::kTransformB; + static ComplexTransform const kTransformB = UnderlyingOperator::kTransformA; + + using UnderlyingArguments = typename UnderlyingOperator::Arguments; + using GemmKernel = typename UnderlyingOperator::GemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1 + ): + 
problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_), + split_k_slices(split_k_slices) { } + }; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmComplex() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static UnderlyingArguments to_underlying_arguments(Arguments const &args) { + return UnderlyingArguments( + {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, + {args.ref_B.data(), args.ref_B.stride(0)}, + {args.ref_A.data(), args.ref_A.stride(0)}, + {args.ref_C.data(), args.ref_C.stride(0)}, + {args.ref_D.data(), args.ref_D.stride(0)}, + args.epilogue, + args.split_k_slices + ); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_grouped.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..3c1c9bc75a81920ed69844b9558d4b3a7b38826c --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_grouped.h @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Device-level grouped GEMM. +*/ + +#pragma once + +#include "cutlass/gemm/device/base_grouped.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class GemmGrouped : public BaseGrouped { +public: + using GemmKernel = GemmKernel_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h new file mode 100644 index 
0000000000000000000000000000000000000000..bdc2e5f327b81524fae86ac37c86cee25e561e20 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h @@ -0,0 +1,385 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Device-level GEMM with layernorm elementwise operations fused in mainloop +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. 
+*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for Scale/Bias vectors + typename ElementScaleBias_, + /// Layout type for Scale/Bias vectors + typename LayoutScaleBias_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + 
DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator +> +class GemmLayernormMainloopFusion : + public GemmUniversalBase< + typename kernel::DefaultGemmLayernormMainloopFusion< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementScaleBias_, + LayoutScaleBias_, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmLayernormMainloopFusion< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementScaleBias_, + LayoutScaleBias_, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone + >::GemmKernel + >; + + using 
Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for Scale/Bias vectors + typename ElementScaleBias_, + /// Layout type for Scale/Bias vectors + typename LayoutScaleBias_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_ +> +class GemmLayernormMainloopFusion { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementScaleBias = ElementScaleBias_; + using LayoutScaleBias = LayoutScaleBias_; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + + using UnderlyingOperator = typename GemmLayernormMainloopFusion< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementScaleBias, + LayoutScaleBias, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + 
ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmLayernormMainloopFusion() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse.h new file mode 100644 index 0000000000000000000000000000000000000000..57f345f41f625e673ed29254954bc392130a82c1 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse.h @@ -0,0 +1,515 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/sparse_gemm.h" + +#include "cutlass/gemm/kernel/default_gemm_sparse.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may + be invoked from host code. + + The contributions of this class are: + + 1. At compile time, it maps data types and high-level structural parameters onto + specific CUTLASS components. + + 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. + + 3. At runtime, it launches kernels on the device. + + The intent is to provide a convenient mechanism for interacting with most plausible GEMM + configurations for each supported architecture. Consequently, not all parameters are exposed + to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy + are selected to tradeoff simplicity of the interface with flexibility. We expect + most configurations to be specified at this level. Applications with more exotic requirements + may construct their kernels of interest using CUTLASS components at the threadblock, warp, + and thread levels of abstraction. + + CUTLASS exposes computations using the functor design pattern in which objects compose some + internal state with an overloaded function call operator. This enables decoupling of + initialization from execution, possibly reducing overhead during steady state phases of + application execution. 
+ + CUTLASS device-level operators expose an Arguments structure encompassing each logical + input to the computation. This is distinct from the kernel-level Params structure pattern + which contains application-specific precomputed state needed by the device code. + + Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN + is as follows: + + // + // Instantiate the CUTLASS GEMM operator. + // + + cutlass::gemm::device::Gemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor + > gemm_op; + + // + // Launch the GEMM operation on the device + // + + cutlass::Status status = gemm_op({ + {m, n, k}, // GemmCoord problem_size, + {A, lda}, // TensorRef ref_A, + {B, ldb}, // TensorRef ref_B, + {C, ldc}, // TensorRef ref_C, + {D, ldd}, // TensorRef ref_D, + {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params + }); + + + A simplified view of the template is listed below. + + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + > + class Gemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level 
swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class SparseGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using MathOperator = Operator; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Define the kernel + using GemmKernel = typename kernel::DefaultSparseGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + 
kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator + >::GemmKernel; + + using ElementE = typename GemmKernel::ElementE; + + using LayoutE = typename GemmKernel::LayoutE; + + static int const kAlignmentE = 128 / sizeof_bits::value; + + static int const kSparse = GemmKernel::kSparse; + static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; + static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + TensorRef ref_E; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + TensorRef ref_E_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1 + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + ref_E(ref_E_), + epilogue(epilogue_), + split_k_slices(split_k_slices) { + + } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + SparseGemm() { } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.ref_E.non_const_ref() + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.ref_E.non_const_ref(), + args.epilogue, + static_cast(workspace) + }; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.ref_E.reset(args.ref_E.non_const_ref().data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + 
return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_universal.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..2c92030c00157f577ce69acca1a48025d52f4799 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_universal.h @@ -0,0 +1,211 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_sparse_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_sparse_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + GemmSparseUniversal is a stateful, reusable Sparse GEMM handle. Once initialized for a given GEMM computation + (problem geometry and data references), it can be reused across different GEMM problems having the + geometry. (Once initialized, details regarding problem geometry and references to workspace memory + cannot be updated.) + + The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. 
The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class GemmSparseUniversal : + public GemmUniversalBase< + typename kernel::DefaultGemmSparseUniversal< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + 
>::GemmKernel + > { + + public: + + static_assert((platform::is_same::value), + "Epilogue of Ampere sparse GEMM must be row major for now."); + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmSparseUniversal< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; + + using ElementE = typename GemmKernel::ElementE; + + using LayoutE = typename GemmKernel::LayoutE; + + static int const kAlignmentE = 128 / sizeof_bits::value; + + static int const kSparse = GemmKernel::kSparse; + static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; + static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_universal_with_absmax.h 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_universal_with_absmax.h new file mode 100644 index 0000000000000000000000000000000000000000..c42c82b47f128b57d9fc3002fd7e750565beed66 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_universal_with_absmax.h @@ -0,0 +1,202 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_sparse_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_sparse_universal_with_absmax.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename 
OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class GemmSparseUniversalWithAbsmax : + public GemmUniversalBase< + typename kernel::DefaultGemmSparseUniversalWithAbsmax< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementC_, + LayoutC_, + 
ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + > { + + public: + + static_assert((platform::is_same::value), + "Epilogue of Ada sparse GEMM must be row major for now."); + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmSparseUniversalWithAbsmax< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; + + using ElementE = typename GemmKernel::ElementE; + + using LayoutE = typename GemmKernel::LayoutE; + + static int const kAlignmentE = 128 / sizeof_bits::value; + + static int const kSparse = GemmKernel::kSparse; + static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; + static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_with_absmax.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_with_absmax.h new file mode 100644 index 0000000000000000000000000000000000000000..5b86f123388502f746e011d27cd3ff07df1d5607 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_with_absmax.h @@ -0,0 +1,360 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a sparse GEMM kernel that computes the absolute maximum of the output tensor + and applies additional scaling factors to operands. +*/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/sparse_gemm.h" + +#include "cutlass/gemm/kernel/default_gemm_sparse_with_absmax.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = 
arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class SparseGemmWithAbsmax { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using 
TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using MathOperator = Operator; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Define the kernel + using GemmKernel = typename kernel::DefaultSparseGemmWithAbsmax< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator + >::GemmKernel; + + using ElementE = typename GemmKernel::ElementE; + + using LayoutE = typename GemmKernel::LayoutE; + + static int const kAlignmentE = 128 / sizeof_bits::value; + + static int const kSparse = GemmKernel::kSparse; + static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; + static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; + + using Arguments = typename GemmKernel::Arguments; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + SparseGemmWithAbsmax() { } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.ref_E.non_const_ref() + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.ref_E.non_const_ref(), + args.ref_Aux, + args.ptr_Vector, + args.ldr, + args.epilogue, + static_cast(workspace) + }; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.ref_E.reset(args.ref_E.non_const_ref().data()); + params_.output_op = args.epilogue; + 
params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_with_visitor.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_with_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..c700733502d12ea17df5dbf5a5beec7b76c0ccec --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_sparse_with_visitor.h @@ -0,0 +1,342 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/sparse_gemm.h" + +#include "cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! Sparse GEMM with visitor + */ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + 
ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename FusionCallbacks_ = + typename cutlass::epilogue::threadblock::detail::EmptyCallbacks, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Number of stages used in the pipelined epilogue + int EpilogueStages = 1> +class SparseGemmWithVisitor { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using FusionCallbacks = FusionCallbacks_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using MathOperator = Operator; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + + /// Define the kernel + using GemmKernel = typename kernel::DefaultSparseGemmWithVisitor< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + 
InstructionShape, + FusionCallbacks, + ThreadblockSwizzle, + kStages, + Operator, + EpilogueStages + >::GemmKernel; + + using ElementE = typename GemmKernel::ElementE; + + using LayoutE = typename GemmKernel::LayoutE; + + static int const kAlignmentE = 128 / sizeof_bits::value; + + static int const kSparse = GemmKernel::kSparse; + static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; + static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_E; + typename FusionCallbacks::Arguments epilogue; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_E_, + typename FusionCallbacks::Arguments epilogue_ = + typename FusionCallbacks::Arguments() + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_E(ref_E_), + epilogue(epilogue_) { + + } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + SparseGemmWithVisitor() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + cutlass::TensorRef(), // It only matters that it's empty. + cutlass::TensorRef(), // Same as above. + args.ref_E.non_const_ref() + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + return bytes; + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + constexpr int SplitKSlices = 1; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + SplitKSlices); + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_E.non_const_ref(), + args.epilogue + }; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_E.reset(args.ref_E.non_const_ref().data()); + params_.output_op = args.epilogue; + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h new file mode 100644 index 0000000000000000000000000000000000000000..1cf506f53d7df39449df73de3034163ccc72606f --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_splitk_parallel.h @@ -0,0 +1,636 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for GEMM performing a reduction over K partitions in parallel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm.h" + +#include "cutlass/gemm/kernel/default_gemm_splitk_parallel.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/reduction/kernel/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +/*! + Gemm device-level operator performing parallel reduction over the K partition. 
+ +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Epilogue output operator + typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert< + ElementAccumulator_, + DefaultGemmConfiguration::EpilogueOutputOp::kCount, + ElementAccumulator_>, + /// Reduction operator + typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator, + 
EpilogueOutputOp_::kCount>, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + threadblock::GemmSplitKHorizontalThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int kAlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int kAlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class GemmSplitKParallel { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ConvertScaledOp = ConvertScaledOp_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ReductionOp = ReductionOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + + /// GEMM kernel + using GemmKernel = typename kernel::DefaultGemmSplitKParallel< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + ConvertScaledOp, + ThreadblockSwizzle, + kStages, + Operator + >::GemmKernel; + + /// Reduction kernel + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + // + // + // + + /// Argument structure 
+ struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + typename ConvertScaledOp::Params convert; + typename ReductionOp::Params reduction; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1, + typename ConvertScaledOp::Params convert_ = + typename ConvertScaledOp::Params(), + typename ReductionOp::Params reduction_ = + typename ReductionOp::Params() + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_), + split_k_slices(split_k_slices), + convert(convert_), + reduction(reduction_) { } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params gemm_params_; + + /// Reduction kernel parameters object + typename ReductionKernel::Params reduction_params_; + +public: + + /// Constructs the GEMM. + GemmSplitKParallel() { } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + return sizeof(ElementAccumulator_) * size_t(args.problem_size.m()) * size_t(args.problem_size.n()) * grid_shape.k(); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + // Define a reference to the workspace - this is an aligned region in device memory. 
+ if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + TensorRef ref_workspace( + static_cast(workspace), + args.problem_size.n()); + + int64_t partition_stride = int64_t(args.problem_size.m()) * int64_t(args.problem_size.n()); + + // Initialize the Params structure + gemm_params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + ref_workspace, + args.convert, + partition_stride + }; + + reduction_params_ = typename ReductionKernel::Params( + args.problem_size.mn(), + grid_shape.k(), + partition_stride, + ref_workspace, + args.ref_D, + args.ref_C.non_const_ref(), + args.epilogue + ); + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + gemm_params_.ref_A.reset(args.ref_A.data()); + gemm_params_.ref_B.reset(args.ref_B.data()); + gemm_params_.ref_D.reset(workspace); + + reduction_params_.ref_D.reset(args.ref_D.data()); + reduction_params_.ref_C.reset(args.ref_C.data()); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + // + // Launch GEMM kernel + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(gemm_params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + + result = cudaFuncSetAttribute( + Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::arch::synclog_setup(); + Kernel<<>>(gemm_params_); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + // + // Launch reduction kernel + // + + block = ReductionKernel::block_shape(); + grid = ReductionKernel::grid_shape(gemm_params_.problem_size.mn()); + + Kernel<<< grid, block, 0, stream >>>(reduction_params_); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Epilogue output operator + typename ConvertScaledOp_, + /// Reduction operator + typename ReductionOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, int kAlignmentA, int kAlignmentB, + /// Operation performed by GEMM + typename Operator_> +class GemmSplitKParallel { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ConvertScaledOp = ConvertScaledOp_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ReductionOp = ReductionOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + + using UnderlyingOperator = GemmSplitKParallel< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ConvertScaledOp, + ReductionOp, + ThreadblockSwizzle, + Stages, + kAlignmentA, + kAlignmentB, + Operator + >; + + using UnderlyingArguments = typename UnderlyingOperator::Arguments; + using GemmKernel = typename UnderlyingOperator::GemmKernel; + using ReductionKernel = typename 
UnderlyingOperator::ReductionKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + typename ConvertScaledOp::Params convert; + typename ReductionOp::Params reduction; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1, + typename ConvertScaledOp::Params convert_ = + typename ConvertScaledOp::Params(), + typename ReductionOp::Params reduction_ = + typename ReductionOp::Params() + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_), + split_k_slices(split_k_slices), + convert(convert_), + reduction(reduction_) { } + }; + +private: + + /// Kernel parameters object + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmSplitKParallel() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static UnderlyingArguments to_underlying_arguments(Arguments const &args) { + return UnderlyingArguments( + {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, + {args.ref_B.data(), args.ref_B.stride(0)}, + {args.ref_A.data(), args.ref_A.stride(0)}, + {args.ref_C.data(), args.ref_C.stride(0)}, + {args.ref_D.data(), args.ref_D.stride(0)}, + args.epilogue, + args.split_k_slices, + args.convert, + args.reduction + ); + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..c2c76eb86ddcb659fa9b41184fb362c45884c719 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal.h @@ -0,0 +1,442 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation + (problem geometry and data references), it can be reused across different GEMM problems having the + geometry. 
(Once initialized, details regarding problem geometry and references to workspace memory + cannot be updated.) + + The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// 
Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout_ = layout::NoPermute, + /// Permute operand A + typename PermuteALayout_ = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout_ = layout::NoPermute +> +class GemmUniversal : + public GemmUniversalBase< + typename kernel::DefaultGemmUniversal< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone, + GatherA, + GatherB, + ScatterD, + PermuteDLayout_, + PermuteALayout_, + PermuteBLayout_ + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using 
EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using PermuteDLayout = PermuteDLayout_; + using PermuteALayout = PermuteALayout_; + using PermuteBLayout = PermuteBLayout_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmUniversal< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone, + GatherA, + GatherB, + ScatterD, + PermuteDLayout_, + PermuteALayout_, + PermuteBLayout_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout_, + /// Permute operand A + typename PermuteALayout_, + /// Permute operand B + typename PermuteBLayout_ +> +class GemmUniversal { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using PermuteDLayout = PermuteDLayout_; + using PermuteALayout = PermuteALayout_; + using 
PermuteBLayout = PermuteBLayout_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmUniversal< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator, + kTransformB, + kTransformA, + GatherB, + GatherA, + ScatterD, + PermuteDLayout, + PermuteBLayout, + PermuteALayout + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversal() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..390e41f899037193ff4b795e9c51b62125854125 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -0,0 +1,784 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cutlass/kernel_launch.h" +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +// 2.x +#include "cutlass/gemm/device/gemm_universal_base.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" + +// 3.x +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::device { + +//////////////////////////////////////////////////////////////////////////////// + +/*! 
+ GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel + of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal. + + It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs + to create it from the host facing arguments. For power users, new static methods + are exposed in 3.x APIs that bypass the stateful methods or args->params lowering. + + It supports kernel types that implement both the 2.x and 3.0 APIs, + however, this is done by specializing the implementation of GemmUniversalAdapter + on the two kernel API types, and thus, GemmUniversalAdapter's behaviour might + differ between the two specializations. +*/ +template +class GemmUniversalAdapter; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Work-around for some DispatchPolicy types not having a Stages member. +// In that case, the Stages value is 0. Most code should static_assert +// that the number of stages is valid. + +// Whether DispatchPolicy::Stages is valid. +// It should also be convertible to int, but if not, that will show up +// as a build error when GemmUniversalAdapter attempts to assign it to kStages. 
+template +struct has_Stages : cute::false_type {}; + +template +struct has_Stages> : cute::true_type {}; + +template +constexpr int stages_member(DispatchPolicy) { + if constexpr (has_Stages::value) { + return DispatchPolicy::Stages; + } + else { + return 0; + } +} + +} // namespace detail + +template +class GemmUniversalAdapter< + GemmKernel_, + cute::enable_if_t>::value>> +{ +public: + using GemmKernel = GetUnderlyingKernel_t; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + // Map back to 2.x type as best as possible + using LayoutA = gemm::detail::StrideToLayoutTagA_t; + using LayoutB = gemm::detail::StrideToLayoutTagB_t; + using LayoutC = gemm::detail::StrideToLayoutTagC_t; + using LayoutD = gemm::detail::StrideToLayoutTagC_t; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static ComplexTransform const kTransformA = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformB = cute::is_same_v ? 
+ ComplexTransform::kConjugate : ComplexTransform::kNone; + + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + using OperatorClass = cutlass::detail::get_operator_class_t; + + using ArchTag = typename GemmKernel::ArchTag; + + // NOTE: Assume identity swizzle for now + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = 
cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA, typename CollectiveMainloop::TiledMma::ValTypeA>(); + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB, typename CollectiveMainloop::TiledMma::ValTypeB>(); + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + // Split-K preserves splits that are 128b aligned + static int constexpr kSplitKAlignment = cute::max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename GemmKernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status + can_implement(Arguments const& args) { + if (GemmKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); + } + + workspace_bytes += GemmKernel::get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); + return GemmKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return GemmKernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = GemmKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + GemmKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != 
result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_ = GemmKernel::to_underlying_arguments(args, workspace); + // Don't set the function attributes - require the CudaHostAdapter to set it. + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. 
+ Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = GemmKernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling GemmKernel::to_underlying_arguments() + static Status + run(Params& params, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + CUTLASS_TRACE_HOST("GemmUniversal::run()"); + dim3 const block = GemmKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = GemmKernel::SharedStorageSize; + + Status launch_result{ Status::kSuccess }; + // Use extended launch API only for mainloops that use it + if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Use extended launch API"); +#endif + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; + [[maybe_unused]] dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + + // Dynamic cluster support + [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 + || GemmKernel::ArchTag::kMinComputeCapability == 101 + || GemmKernel::ArchTag::kMinComputeCapability == 103 + ) { + if constexpr (!cute::is_static_v) { + fallback_cluster = 
params.hw_info.cluster_shape_fallback; + cluster = params.hw_info.cluster_shape; + } + } + + [[maybe_unused]] void* kernel_params[] = {¶ms}; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + if (launch_with_pdl) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter."); + return Status::kErrorInternal; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif + if constexpr (is_static_1x1x1) { + launch_result = cuda_adapter->launch(grid, + block, + smem_size, + stream, + kernel_params, + 0); + } + else { + launch_result = cuda_adapter->launch(grid, + cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel_params, + 0); + } + } + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA host adapter is null"); + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + [[maybe_unused]] void const* kernel = (void const*) device_kernel; + static constexpr bool kClusterLaunch = GemmKernel::ArchTag::kMinComputeCapability == 90; + if constexpr (kClusterLaunch) { + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching dynamic cluster kernel"); +#endif + launch_result = ClusterLauncher::launch( + grid, cluster, 
block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + } + } + + else { + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 + || GemmKernel::ArchTag::kMinComputeCapability == 101 + || GemmKernel::ArchTag::kMinComputeCapability == 120 + || GemmKernel::ArchTag::kMinComputeCapability == 103 + ) { + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch(grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with fall-back cluster"); +#endif + launch_result = ClusterLauncher::launch_with_fallback_cluster( + grid, + cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel, + kernel_params, + launch_with_pdl); + } + } + } + + } + } + else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms}; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: CUDA host adapter is null"); + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with cutlass::kernel_launch"); +#endif + launch_result = cutlass::kernel_launch( + grid, 
block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: cudaGetLastError reports success"); +#endif + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, launch_with_pdl); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. 
+ Status + run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 2.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalAdapter< + GemmKernel_, + cute::enable_if_t>::value>> +{ +public: + + using GemmKernel = GetUnderlyingKernel_t; + + static bool const kInternalTranspose = + !cutlass::epilogue::threadblock::detail::is_2x_evt_v && // 2.x EVT does not require internal transpose + cute::is_same::value; + + using ThreadblockShape = typename GemmKernel::Mma::Shape; + using WarpShape = typename GemmKernel::WarpShape; + using InstructionShape = typename GemmKernel::InstructionShape; + + // warp-level, arch-level (instruction), math operator + using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator; + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + + // Operator class and arch tag extract bottom-up + // set it for top-level gemm device-level template + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + // Type, layout, and complex transform deliberately exchanged with B + using MapArguments = kernel::detail::MapArguments< + typename GemmKernel::ElementA, + typename GemmKernel::LayoutA, + GemmKernel::kTransformA, + GemmKernel::kAlignmentA, + typename 
GemmKernel::ElementB, + typename GemmKernel::LayoutB, + GemmKernel::kTransformB, + GemmKernel::kAlignmentB, + typename GemmKernel::LayoutC, + kInternalTranspose + >; + + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static int const kAlignmentA = MapArguments::kAlignmentA; + + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + static int const kAlignmentB = MapArguments::kAlignmentB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename MapArguments::LayoutC; + static int const kAlignmentC = GemmKernel::kAlignmentC; + + // C and D same type for 2.x kernel + using ElementD = ElementC; + using LayoutD = LayoutC; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + static int const kStages = GemmKernel::Mma::kStages; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using UnderlyingOperator = GemmUniversalBase; + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversalAdapter() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + if (kInternalTranspose) { + return args.transposed_problem(); + } + else { + return args; + } + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args), cuda_adapter); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args), cuda_adapter); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr + ) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream, cuda_adapter); + } + + /// Lightweight update given a subset of arguments. + Status update(Arguments const &args) { + + return underlying_operator_.update(to_underlying_arguments(args)); + } + + /// Runs the kernel using initialized state. + Status run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + return underlying_operator_.run(stream, cuda_adapter); + } + + /// Runs the kernel using initialized state. + Status operator()( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (status == Status::kSuccess) { + status = run(stream, cuda_adapter); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_base.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_base.h new file mode 100644 index 0000000000000000000000000000000000000000..5f836ecdc3a2b75c264c9ec66aa2dc023c05dc23 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_base.h @@ -0,0 +1,521 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates streamk, batched strided, and batched array variants. 
+*/ + +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(limits) +#else +#include +#endif + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template +class GemmUniversalBase { +public: + + using GemmKernel = GemmKernel_; + + /// Boolean indicating whether the CudaHostAdapter is enabled + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + /// Numerical accumulation element type + using ElementAccumulator = typename GemmKernel::Mma::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename 
GemmKernel::Arguments; + + + /// Index of the GEMM Kernel within the CudaHostAdapter + static int32_t const kGemmKernelIndex = 0; + + /// Kernel dynamic shared memory allocation requirement + /// Update the kernel function's shared memory configuration for the current device + static constexpr size_t kSharedStorageSize = sizeof(typename GemmKernel::SharedStorage); + +protected: + + // + // Device properties (uniform across all instances of the current thread) + // + + // Device ordinal + CUTLASS_THREAD_LOCAL static int device_ordinal_; + + /// Device SM count + CUTLASS_THREAD_LOCAL static int device_sms_; + + /// Kernel SM occupancy (in thread blocks) + CUTLASS_THREAD_LOCAL static int sm_occupancy_; + +protected: + + /// Initialize static thread-local members for the thread's current device, + /// if necessary. + static Status init_device_props() + { + CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()"); + + cudaError_t cudart_result; + + // Get current device ordinal + int current_ordinal; + cudart_result = cudaGetDevice(¤t_ordinal); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + + // Done if matches the current static member + if (current_ordinal == device_ordinal_) { + // Already initialized + return Status::kSuccess; + } + + // Update SM count member + cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + + // If requires more than 48KB: configure for extended, dynamic shared memory + if constexpr (kSharedStorageSize >= (48 << 10)) + { + cudart_result = cudaFuncSetAttribute( + Kernel2, + cudaFuncAttributeMaxDynamicSharedMemorySize, + kSharedStorageSize); + if (cudart_result != cudaSuccess) { + 
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + } + + // Update SM occupancy member + cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + &sm_occupancy_, + Kernel2, + GemmKernel::kThreadCount, + kSharedStorageSize, + cudaOccupancyDisableCachingOverride); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + + // Update device ordinal member on success + device_ordinal_ = current_ordinal; + + CUTLASS_TRACE_HOST(" " + "device_ordinal: (" << device_ordinal_ << "), " + "device_sms: (" << device_sms_ << "), " + "sm_occupancy: (" << sm_occupancy_ << ") " + "smem_size: (" << kSharedStorageSize << ") " + "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")"); + + return Status::kSuccess; + } + + +protected: + + // + // Instance data members + // + + /// Kernel parameters + typename GemmKernel::Params params_; + + + /// Initialize params member + Status init_params(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) + { + int32_t device_sms = 0; + int32_t sm_occupancy = 0; + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + + // + // Occupancy query using CudaHostAdapter::query_occupancy(). 
+ // + + if (cuda_adapter) { + + Status status = cuda_adapter->query_occupancy( + &device_sms, + &sm_occupancy, + kGemmKernelIndex, + GemmKernel::kThreadCount, + kSharedStorageSize); + + CUTLASS_ASSERT(status == Status::kSuccess); + + if (status != Status::kSuccess) { + return status; + } + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + + // Initialize static device properties, if necessary + Status result = init_device_props(); + + if (result != Status::kSuccess) { + return result; + } + + // + // Use thread-local static members for occupancy query initialized by call to + // `init_device_props()` + // + + device_sms = device_sms_; + sm_occupancy = sm_occupancy_; + } + + // Initialize params member + params_ = typename GemmKernel::Params(args, device_sms, sm_occupancy); + return Status::kSuccess; + } + +public: + + //--------------------------------------------------------------------------------------------- + // Stateless API + //--------------------------------------------------------------------------------------------- + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()"); + + if (!kEnableCudaHostAdapter || cuda_adapter) { + + dim3 grid = get_grid_shape(args, cuda_adapter); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) + { + return Status::kErrorInvalidProblem; + } + } + else { + // + // With a null host adapter, a conservative grid shape is computed and required to conform to CUDA grid + // dimension limits. 
+ // + + int64_t logicalGridM = (int64_t(args.problem_size.m()) + ThreadblockShape::kM - 1) / ThreadblockShape::kM; + int64_t logicalGridN = (int64_t(args.problem_size.n()) + ThreadblockShape::kN - 1) / ThreadblockShape::kN; + int32_t logicalGridL = args.batch_count; + + if ((int64_t(std::numeric_limits::max()) < logicalGridM) || + (int64_t(std::numeric_limits::max()) < logicalGridN) || + (int32_t(std::numeric_limits::max()) < logicalGridL)) { + + return Status::kErrorInvalidProblem; + } + + } + + return GemmKernel::can_implement(args); + } + + + /// Returns the workspace size (in bytes) needed for the problem + /// geometry expressed by these arguments + static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); + + // Initialize parameters from args + GemmUniversalBase base; + if (base.init_params(args, cuda_adapter) != Status::kSuccess) { + return 0; + } + + // Get size from parameters + size_t workspace_bytes = base.params_.get_workspace_size(); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + return workspace_bytes; + } + + + /// Returns the grid extents in thread blocks to launch + static dim3 get_grid_shape(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); + + // Initialize parameters from args + GemmUniversalBase base; + if (base.init_params(args, cuda_adapter) != Status::kSuccess) { + return dim3(0,0,0); + } + + // Get dims from parameters + dim3 grid_dims = base.params_.get_grid_dims(); + + CUTLASS_TRACE_HOST( + " tiled_shape: " << base.params_.get_tiled_shape() << "\n" + << " grid_dims: {" << grid_dims << "}"); + + return grid_dims; + } + + + /// Returns the maximum number of active thread blocks per multiprocessor + static int maximum_active_blocks(CudaHostAdapter *cuda_adapter = nullptr) + { + 
CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); + + int32_t device_sms = 0; + int32_t sm_occupancy = 0; + + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + + if (cuda_adapter) { + + Status status = cuda_adapter->query_occupancy( + &device_sms, + &sm_occupancy, + kGemmKernelIndex, + GemmKernel::kThreadCount, + kSharedStorageSize); + + CUTLASS_ASSERT(status == Status::kSuccess); + + if (status != Status::kSuccess) { + return -1; + } + } + else { + return -1; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + // Initialize static device properties, if necessary + if (init_device_props() != Status::kSuccess) { + return -1; + } + + sm_occupancy = sm_occupancy_; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_); + return sm_occupancy; + } + + + //--------------------------------------------------------------------------------------------- + // Stateful API + //--------------------------------------------------------------------------------------------- + + /// Initializes GEMM state from arguments and workspace memory + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize parameters from args + Status result = init_params(args, cuda_adapter); + if (result != Status::kSuccess) { + return result; + } + + // Assign and prepare workspace memory + if (args.mode == GemmUniversalMode::kGemm) { + return params_.init_workspace(workspace, stream); + } + + return Status::kSuccess; + } + + + /// Lightweight update given a subset of arguments. + Status update(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversalBase()::update()"); + params_.update(args); + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::run()"); + + // Configure grid and block dimensions + dim3 block(GemmKernel::kThreadCount, 1, 1); + dim3 grid = params_.get_grid_dims(); + + // Launch kernel + CUTLASS_TRACE_HOST(" " + "grid: (" << grid << "), " + "block: (" << block << "), " + "SMEM: (" << kSharedStorageSize << ")"); + + cutlass::arch::synclog_setup(); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms_}; + return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, kernel_params, 0); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + + Kernel2<<>>(params_); + + // Query for errors + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) + { + return run(stream, cuda_adapter); + } + + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) + { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (status == Status::kSuccess) { + status = run(stream, cuda_adapter); + } + + return status; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Static initializers +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Device ordinal +template +CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_ordinal_ = -1; + +/// Device SM count +template +CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_sms_ = -1; + +/// Kernel SM occupancy (in thread blocks) +template +CUTLASS_THREAD_LOCAL int GemmUniversalBase::sm_occupancy_ = -1; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..84d148d8418b249b98e86839b8641afd0c7c5cf9 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h @@ -0,0 +1,386 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a Stream-K GEMM kernel that can broadcast bias vector in the + epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + The universal GEMM with a broadcast epilogue. + Supports +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + ElementC_, ElementAccumulator_, ElementAccumulator_, + ElementC_, ElementC_, 128 / cutlass::sizeof_bits::value>, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone +> +class GemmUniversalStreamkWithBroadcast : + public GemmUniversalBase< + typename kernel::DefaultGemmStreamkWithBroadcast< + ElementA_, + LayoutA_, 
+ TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmStreamkWithBroadcast< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB> +class GemmUniversalStreamkWithBroadcast { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape 
= ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmUniversalStreamkWithBroadcast< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator, + kTransformB, + kTransformA + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversalStreamkWithBroadcast() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_with_absmax.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_with_absmax.h new file mode 100644 index 0000000000000000000000000000000000000000..d2172d639cb95962b61eca1cad820a33afd31ab0 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_with_absmax.h @@ -0,0 +1,404 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a GEMM kernel that computes the absolute maximum of the output tensor + and applies additional scaling factors to operands. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_with_absmax.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Universal GEMM with absolute-maximum calculation and scaling +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_ = arch::Sm89, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + ElementC_, ElementAccumulator_, ElementAccumulator_, + ElementC_, ElementC_, 128 / cutlass::sizeof_bits::value>, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone +> +class GemmUniversalWithAbsMax; + +// Partial specialization for SM89 +template < + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename 
LayoutC_, + typename ElementAccumulator_, + typename ThreadblockShape_, + typename WarpShape_, + typename InstructionShape_, + typename EpilogueOutputOp_, + typename ThreadblockSwizzle_, + int Stages, + int AlignmentA, + int AlignmentB, + typename Operator_, + ComplexTransform TransformA, + ComplexTransform TransformB +> +class GemmUniversalWithAbsMax< + ElementA_, + LayoutA_, + ElementB_, + LayoutB_, + ElementC_, + LayoutC_, + ElementAccumulator_, + arch::OpClassTensorOp, + arch::Sm89, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + AlignmentA, + AlignmentB, + Operator_, + TransformA, + TransformB +> : + public GemmUniversalBase< + typename kernel::DefaultGemmWithAbsMax< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + arch::OpClassTensorOp, + arch::Sm89, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm89; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmWithAbsMax< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + 
ElementAccumulator_, + OperatorClass, + ArchTag, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for SM89 column-major output exchanges problem size and operand. +template < + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename ElementAccumulator_, + typename ThreadblockShape_, + typename WarpShape_, + typename InstructionShape_, + typename EpilogueOutputOp_, + typename ThreadblockSwizzle_, + int Stages, + int AlignmentA, + int AlignmentB, + typename Operator_, + ComplexTransform TransformA, + ComplexTransform TransformB> +class GemmUniversalWithAbsMax { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm89; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmUniversalWithAbsMax< + ElementB, + typename layout::LayoutTranspose::type, 
+ ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator, + kTransformB, + kTransformA + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversalWithAbsMax() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_with_broadcast.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_with_broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..f04bf8d5f27404a77f7851f22882832865559c63 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_universal_with_broadcast.h @@ -0,0 +1,386 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a GEMM kernel that can broadcast bias vector in the + epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + The universal GEMM with a broadcast epilogue. + Supports +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + ElementC_, ElementAccumulator_, ElementAccumulator_, + ElementC_, ElementC_, 128 / cutlass::sizeof_bits::value>, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone +> +class GemmUniversalWithBroadcast : + public GemmUniversalBase< + typename kernel::DefaultGemmWithBroadcast< + ElementA_, + LayoutA_, + TransformA, 
+ AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmWithBroadcast< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB> +class GemmUniversalWithBroadcast { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = 
ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmUniversalWithBroadcast< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator, + kTransformB, + kTransformA + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversalWithBroadcast() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_with_k_reduction.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_with_k_reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..5bde1161c700e822c89b2d5102ac5365a02b51e4 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemm_with_k_reduction.h @@ -0,0 +1,415 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a GEMM kernel that can reduce one of the input matrix + into a vector along the K dimension. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_with_k_reduction.h" + +#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. 
+*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Reduce A or B operand along the K dimension + bool ReduceKForA_ = true, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + 
int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute +> +class GemmWithKReduction : + public GemmUniversalBase< + typename kernel::DefaultGemmWithKReduction< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ReduceKForA_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static constexpr int kStages = Stages; + static constexpr int kAlignmentA = AlignmentA; + static constexpr int kAlignmentB = AlignmentB; + static constexpr int kAlignmentC = EpilogueOutputOp::kCount; + static constexpr ComplexTransform kTransformA = TransformA; + 
static constexpr ComplexTransform kTransformB = TransformB; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmWithKReduction< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ReduceKForA_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Reduce A or B operand along the K dimension + bool ReduceKForA_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout +> +class GemmWithKReduction { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + 
static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmWithKReduction< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + !ReduceKForA_, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator, + kTransformB, + kTransformA, + GatherB, + GatherA, + ScatterD, + PermuteDLayout + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmWithKReduction() = default; + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemv.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemv.h new file mode 100644 index 0000000000000000000000000000000000000000..763f18e8ec04b445220000dd63098792c4a8e48d --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemv.h @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Gemv { +public: + + using GemvKernel = GemvKernel_; + + + using ElementA = typename GemvKernel::ElementA; + using LayoutA = typename GemvKernel::LayoutA; + using ElementB = typename GemvKernel::ElementB; + using ElementC = typename GemvKernel::ElementC; + + using ElementAccumulator = typename GemvKernel::ElementAccumulator; + using EpilogueOutputOp = typename GemvKernel::EpilogueOutputOp; + + static ComplexTransform const kTransformA = GemvKernel::kTransformA; + static ComplexTransform const kTransformB = GemvKernel::kTransformB; + + static int const kThreadCount = GemvKernel::kThreadCount; + static int const kThreadsPerRow = GemvKernel::kThreadsPerRow; + + using Arguments = typename GemvKernel::Arguments; + using Params = typename GemvKernel::Params; + +private: + + Params params_; + +public: + + /// Constructs the Gemv. + Gemv() { } + + /// Determines whether the Gemv can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return GemvKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return 0; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args, dim3 const &block) { + if(platform::is_same::value) { + return dim3((args.problem_size.row() + (block.x - 1)) / block.x, 1, args.batch_count % 65536); + } + else { + return dim3((args.problem_size.row() + (block.y - 1)) / block.y, 1, args.batch_count % 65536); + } + } + + /// Computes the block shape + static dim3 get_block_shape() { + if(platform::is_same::value) { + return dim3(kThreadCount, 1, 1); + } + else { + return dim3(kThreadsPerRow, kThreadCount / kThreadsPerRow, 1); + } + } + + /// Initializes Gemv state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + params_ = Params(args); + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + return params_.update(args); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + dim3 block = get_block_shape(); + dim3 grid = get_grid_shape(params_, block); + + int smem_size = int(sizeof(typename GemvKernel::SharedStorage)); + + // Launch + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemv_blockscaled.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemv_blockscaled.h new file mode 100644 index 0000000000000000000000000000000000000000..b4dc0dd3061c9dc00e184881689ed0bb74e1921b --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/gemv_blockscaled.h @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemvBlockScaled { +public: + + using GemvKernel = GemvKernel_; + + + using ElementA = typename GemvKernel::ElementA; + using LayoutA = typename GemvKernel::LayoutA; + using ElementB = typename GemvKernel::ElementB; + using ElementC = typename GemvKernel::ElementC; + + using ElementSFA = typename 
GemvKernel::ElementSFA; + using ElementSFB = typename GemvKernel::ElementSFB; + + using ElementAccumulator = typename GemvKernel::ElementAccumulator; + using EpilogueOutputOp = typename GemvKernel::EpilogueOutputOp; + + static ComplexTransform const kTransformA = GemvKernel::kTransformA; + static ComplexTransform const kTransformB = GemvKernel::kTransformB; + + static int const kThreadCount = GemvKernel::kThreadCount; + static int const kThreadsPerRow = GemvKernel::kThreadsPerRow; + + using Arguments = typename GemvKernel::Arguments; + using Params = typename GemvKernel::Params; + +private: + + Params params_; + +public: + + /// Constructs the GemvBlockScaled. + GemvBlockScaled() = default; + + /// Determines whether the GemvBlockScaled can execute the given problem. + static Status can_implement(Arguments const &args) { + + return GemvKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return 0; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args, dim3 const &block) { + if(platform::is_same::value) { + return dim3((args.problem_size.row() + (block.x - 1)) / block.x, 1, args.batch_count % 65536); + } + else { + return dim3((args.problem_size.row() + (block.y - 1)) / block.y, 1, args.batch_count % 65536); + } + } + + /// Computes the block shape + static dim3 get_block_shape() { + if(platform::is_same::value) { + return dim3(kThreadCount, 1, 1); + } + else { + return dim3(kThreadsPerRow, kThreadCount / kThreadsPerRow, 1); + } + } + + /// Initializes GemvBlockScaled state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + params_ = Params(args); + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + return params_.update(args); + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + const dim3 block = get_block_shape(); + const dim3 grid = get_grid_shape(params_, block); + + int smem_size = int(sizeof(typename GemvKernel::SharedStorage)); + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + if (result == cudaSuccess) { + return Status::kSuccess; + } else { + return Status::kErrorInternal; + } + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/rank_2k.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/rank_2k.h new file mode 100644 index 0000000000000000000000000000000000000000..293ca06a3a943ef83ca63bf6e6cc545e052c0a1a --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/rank_2k.h @@ -0,0 +1,548 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined Rank2K kernel. Does not compute batching or support split-K. 
+ + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/rank_2k_universal.h" + +#include "cutlass/gemm/kernel/default_rank_2k_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename 
EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by SYRK + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation + ComplexTransform TransformB = ComplexTransform::kNone, + /// Blas3 computation mode (symmetric/hermitian) + BlasMode BlasMode_ = BlasMode::kSymmetric> +class Rank2K { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static FillMode const kFillModeC = FillModeC; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = 
EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + static BlasMode const kBlasMode = BlasMode_; + static int const kUpdateRank = 2; + + // static asserts for rank 2k update kernel + static_assert(platform::is_same::value, + "Rank 2K update operator support same layouts for operandA and B"); + + /// Define the kernel + using Rank2Kkernel = typename kernel::DefaultRank2KUniversal< + ElementA, + LayoutA, + kTransformA, + kAlignmentA, + ElementB, + LayoutB, + kTransformB, + kAlignmentB, + ElementC, + LayoutC, + kFillModeC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + kBlasMode + >::Rank2Kkernel; + + using Arguments = typename Rank2Kkernel::Arguments; + +private: + + /// Kernel parameters object + typename Rank2Kkernel::Params params_; +public: + + /// Constructs the SYRK. + Rank2K() { } + + /// Determines whether the SYRK can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.batch_count > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = Rank2Kkernel::can_implement(args); + + if (FillModeC != FillMode::kLower && FillModeC != FillMode::kUpper) { + return Status::kErrorInvalidProblem; + } + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (kSplitKSerial && args.batch_count > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes SYRK state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (kSplitKSerial) { + if (args.batch_count > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.batch_count > 1) { + return Status::kErrorInvalidProblem; + } + } + + int gemm_k_size = args.problem_size.k(); + + // Initialize the Params structure + params_ = typename Rank2Kkernel::Params{ + args, + grid_tiled_shape, + gemm_k_size, + static_cast(workspace) + }; + + int smem_size = int(sizeof(typename 
Rank2Kkernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.batch_count > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(Rank2Kkernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename Rank2Kkernel::SharedStorage)); + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchange operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial, + /// Operation performed by Rank2K update kernel + typename Operator_, + /// Complex elementwise transformation + ComplexTransform TransformA, + /// Complex elementwise transformation + ComplexTransform TransformB, + /// Blas3 computation mode (symmetric/hermitian) + BlasMode BlasMode_ + > +class Rank2K { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = 
OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static FillMode const kFillModeC = FillModeC; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static BlasMode const kBlasMode = BlasMode_; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + static int const kUpdateRank = 2; + + /// Define the kernel + using UnderlyingOperator = typename cutlass::gemm::device::Rank2K< + ElementB, + LayoutB, + ElementA, + LayoutA, + ElementC, + layout::RowMajor, + InvertFillMode::mode, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kAlignmentB, + kAlignmentA, + kSplitKSerial, + Operator, + kTransformA, + kTransformB, + kBlasMode + >; + + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + using Rank2Kkernel = typename UnderlyingOperator::Rank2Kkernel; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the Rank2K. + Rank2K() { } + + /// Helper to construct a transposed equivalent for the underlying Rank2K operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the Rank2K can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes Rank2K state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace Rank2K +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/rank_2k_grouped.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/rank_2k_grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..0c59744b5a9b6c7e98aa66a7b8ddb998413ed46e --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/rank_2k_grouped.h @@ -0,0 +1,63 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Device-level grouped Rank2K. +*/ + +#pragma once + +#include "cutlass/gemm/device/base_grouped.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Rank2K Grouped +template +class Rank2KGrouped : public BaseGrouped { +public: + using Rank2Kkernel = Rank2Kkernel_; + static const cutlass::FillMode kFillModeC = Rank2Kkernel::kFillModeC; + static const cutlass::BlasMode kBlasMode = Rank2Kkernel::kBlasMode; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/rank_k.h 
b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/rank_k.h new file mode 100644 index 0000000000000000000000000000000000000000..80c420cd8a73859183a013fbd1b10ca0f46cbc0d --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/rank_k.h @@ -0,0 +1,510 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined RankK kernel. Does not compute batching or support split-K. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/rank_k_universal.h" + +#include "cutlass/gemm/kernel/default_rank_k_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename 
ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by SYRK + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementA_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation + ComplexTransform TransformA = ComplexTransform::kNone, + /// Blas3 computation mode (symmetric/hermitian) + BlasMode BlasMode_ = BlasMode::kSymmetric> +class RankK { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; 
+ using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static FillMode const kFillModeC = FillModeC; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = TransformA; + static BlasMode const kBlasMode = BlasMode_; + static int const kUpdateRank = 1; + + /// Define the kernel + using RankKkernel = typename kernel::DefaultRankKUniversal< + ElementA, + LayoutA, + kTransformA, + kAlignmentA, + ElementC, + LayoutC, + kFillModeC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + kBlasMode + >::RankKkernel; + + using Arguments = typename RankKkernel::Arguments; + +private: + + /// Kernel parameters object + typename RankKkernel::Params params_; +public: + + /// Constructs the SYRK. + RankK() { } + + /// Determines whether the SYRK can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.batch_count > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = RankKkernel::can_implement(args); + + if (FillModeC != FillMode::kLower && FillModeC != FillMode::kUpper) { + return Status::kErrorInvalidProblem; + } + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (kSplitKSerial && args.batch_count > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes SYRK state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (kSplitKSerial) { + if (args.batch_count > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.batch_count > 1) { + return Status::kErrorInvalidProblem; + } + } + + int gemm_k_size = args.problem_size.k(); + + // Initialize the Params structure + params_ = typename RankKkernel::Params{ + args, + grid_tiled_shape, + gemm_k_size, + static_cast(workspace) + }; + + int smem_size = int(sizeof(typename 
RankKkernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.batch_count > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(RankKkernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename RankKkernel::SharedStorage)); + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchange operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial, + /// Operation performed by RankK update kernel + typename Operator_, + /// Complex elementwise transformation + ComplexTransform TransformA, + /// Blas3 computation mode (symmetric/hermitian) + BlasMode BlasMode_ + > +class RankK { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static FillMode const kFillModeC = FillModeC; 
+ static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static BlasMode const kBlasMode = BlasMode_; + static int const kUpdateRank = 1; + + // Complex transform for input A matrices (function on input layout) + static ComplexTransform const kTransformA = TransformA; + + /// Define the kernel + using UnderlyingOperator = typename cutlass::gemm::device::RankK< + ElementA, + LayoutA, + ElementC, + layout::RowMajor, + InvertFillMode::mode, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kAlignmentA, + kSplitKSerial, + Operator, + kTransformA, + kBlasMode + >; + + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + using RankKkernel = typename UnderlyingOperator::RankKkernel; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the RankK. + RankK() { } + + /// Helper to construct a transposed equivalent for the underlying RankK operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args; + } + + /// Determines whether the RankK can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes RankK state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace RankK +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/symm.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/symm.h new file mode 100644 index 0000000000000000000000000000000000000000..538d294f83e24955c2354cfaceeb79e835fc28cd --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/symm.h @@ -0,0 +1,603 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined SYMM and HEMM kernels. Does not compute batching or support split-K. 
+ + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/symm_universal.h" + +#include "cutlass/gemm/kernel/default_symm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Side Mode for A (kLeft or kRight) + SideMode SideModeA, + /// Fill Mode for A (kLower or kUpper) + FillMode FillModeA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + 
ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = epilogue::thread::LinearCombination< + ElementC_, + 128 / sizeof_bits::value, + ElementAccumulator_, + ElementAccumulator_, + epilogue::thread::ScaleType::OnlyAlphaScaling + >, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by SYMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Blas3 computation mode (symmetric/hermitian) + BlasMode BlasMode_ = BlasMode::kSymmetric> +class Symm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementB_, ElementA_>::type; + using LayoutAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutB_, LayoutA_>::type; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementA_, ElementB_>::type; + using LayoutBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutA_, LayoutB_>::type; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using 
InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static SideMode const kSideModeA = SideModeA; + static FillMode const kFillModeA = FillModeA; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentAKernel = (SideModeA == SideMode::kRight) ? AlignmentB : AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentBKernel = (SideModeA == SideMode::kRight) ? AlignmentA : AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static BlasMode const kBlasMode = BlasMode_; + + // static asserts for symm update kernel + static_assert(platform::is_same::value, + "SYMM update operator support same layouts for operand A and B"); + + /// Define the kernel + using SymmKernel = typename kernel::DefaultSymmUniversal< + ElementAKernel, + LayoutAKernel, + kSideModeA, + kFillModeA, + kAlignmentAKernel, + ElementBKernel, + LayoutBKernel, + kAlignmentBKernel, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + kBlasMode + >::SymmKernel; + + using Arguments = typename SymmKernel::Arguments; + +private: + + /// Kernel parameters object + typename SymmKernel::Params params_; +public: + + /// Constructs the SYMM. + Symm() { } + + /// Determines whether the SYMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.batch_count > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = SymmKernel::can_implement(args); + + if (SideModeA == SideMode::kInvalid) { + return Status::kErrorInvalidProblem; + } + + if (FillModeA != FillMode::kLower && FillModeA != FillMode::kUpper) { + return Status::kErrorInvalidProblem; + } + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (kSplitKSerial && args.batch_count > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes SYMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (kSplitKSerial) { + if (args.batch_count > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.batch_count > 1) { + return Status::kErrorInvalidProblem; + } + } + + int gemm_k_size = args.problem_size.k(); + + // Swapping argument for A and B, if A was on the right side (problem size doesn't need to change here). 
+ if (kSideModeA == SideMode::kRight) { + // Initialize the Params structure + params_ = typename SymmKernel::Params{ + args.swapped_matrices(), + grid_tiled_shape, + gemm_k_size, + static_cast(workspace) + }; + + return Status::kSuccess; + } + + // Initialize the Params structure + params_ = typename SymmKernel::Params{ + args, + grid_tiled_shape, + gemm_k_size, + static_cast(workspace) + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.batch_count > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(SymmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename SymmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; +//////////////////////////////////////////////////////////////////////////////// + +/******************************************************************************************************** + SYMM/HEMM has 4 combinations based on Layouts {RowMajor, ColumnMajor} x Side mode {LeftSide, RightSide} + In templates and arguments to cutlass kernel, `matrix A` is always symmetric/hermitian, and `matrix B` is rectangular. + (adhering to the cuBLAS convention) + + Although, cuBLAS SYMM/HEMM only supports ColumnMajor layouts for all matrices (A, B, C/D). + + For the mainloop and symm kernel, `A` and `B` points to left-side and right-side matrices, respectively. + + Thus, for LeftSide mode `A` and `B` points to `matrix A` and `matrix B`, respectively. While for + the RightSide mode `A` and `B` points to `matrix B` and `matrix A`, respectively. + + Additionally, CUTLASS GEMM epilogue is always RowMajor, and ColumnMajor output is achieved by + transposing the GEMM problem. Thus, ColumnMajor output layout for SYMM/HEMM requires: + - Transposing `matrix A` and `matrix B` layouts + - Swapping problem size m and n values + - Swapping LeftSide and RightSide mode + + RowMajor output: D = matrix A x matrix B + ColumnMajor output: D = matrix A x matrix B -> Transpose (D) = Transpose(matrix B) x Transpose(matrix A) + + {RowMajor, ColumnMajor} x Side Mode {LeftSide, RightSide} 4 cases: + 1. LeftSide mode and RowMajor output (default template) + 2. LeftSide mode and ColumnMajor output + 3. RightSide mode and RowMajor output + 4. 
RightSide mode and ColumnMajor output + + Mapping ColumnMajor output layout cases 2 and 4 to RowMajor efficient epilogue implementation: + + Case 2 -> Case 3: + D_col = matrix A x matrix B (LeftSide mode) + => Transpose(D_col) = Transpose(matrix B) x Transpose(matrix A) (RightSide mode) + + swap pointers for `A` and `B` call GEMM mainloop with RowMajor efficient-epilogue + + Case 4 -> Case 1: + D_col = matrix B x matrix A (RightSide mode) + => Transpose(D_col) = Transpose(matrix A) x Transpose(matrix B) (LeftSide mode) + + call GEMM mainloop for with RowMajor efficient-epilogue +********************************************************************************************************/ + +/// Partial specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Side Mode for A (kLeft or kRight) + SideMode SideModeA, + /// Fill Mode for A (kLower or kUpper) + FillMode FillModeA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial, + /// Operation performed by Symm update kernel + typename Operator_, + /// Blas3 computation mode (symmetric/hermitian) + BlasMode BlasMode_ + > +class Symm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static SideMode const kSideModeA = SideModeA; + static FillMode const kFillModeA = FillModeA; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static BlasMode const kBlasMode = BlasMode_; + + /// Define the kernel + using UnderlyingOperator = typename cutlass::gemm::device::Symm< + ElementA, + typename layout::LayoutTranspose::type, + InvertSideMode::mode, + InvertFillMode::mode, 
+ ElementB, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kAlignmentA, + kAlignmentB, + kSplitKSerial, + Operator, + kBlasMode + >; + + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + using SymmKernel = typename UnderlyingOperator::SymmKernel; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the Symm. + Symm() { } + + /// Helper to construct a transposed equivalent for the underlying SYMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem_size(); + } + + /// Determines whether the Symm can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes Symm state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace Symm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/trmm.h b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/trmm.h new file mode 100644 index 0000000000000000000000000000000000000000..46f6473e8a201a22ee3f4b55783f0a5d24b91d54 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/device/trmm.h @@ -0,0 +1,759 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a TRMM kernel. Does not compute batching or support split-K. 
+ + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/trmm_universal.h" + +#include "cutlass/gemm/kernel/default_trmm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! Trmm device-level operator. This is an interface to efficient CUTLASS TRMM kernels that may + be invoked from host code. + + The contributions of this class are: + + 1. At compile time, it maps data types and high-level structural parameters onto + specific CUTLASS components. + + 2. At runtime, it maps logical arguments to TRMM problems to kernel parameters. + + 3. At runtime, it launches kernels on the device. + + The intent is to provide a convenient mechanism for interacting with most plausible TRMM + configurations for each supported architecture. Consequently, not all parameters are exposed + to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy + are selected to tradeoff simplicity of the interface with flexibility. We expect + most configurations to be specified at this level. Applications with more exotic requirements + may construct their kernels of interest using CUTLASS components at the threadblock, warp, + and thread levels of abstraction. + + CUTLASS exposes computations using the functor design pattern in which objects compose some + internal state with an overloaded function call operator. This enables decoupling of + initialization from execution, possibly reducing overhead during steady state phases of + application execution. 
+ + CUTLASS device-level operators expose an Arguments structure encompassing each logical + input to the computation. This is distinct from the kernel-level Params structure pattern + which contains application-specific precomputed state needed by the device code. + + Example of a CUTLASS TRMM operator implementing the functionality of cuBLAS's STRMM NN + is as follows: + + // + // Instantiate the CUTLASS TRMM operator. + // + + cutlass::gemm::device::Trmm< + float, + cutlass::layout::ColumnMajor, + cutlass::SideMode::kLeft, + cutlass::FillMode::kLower, + cutlass::DiagType::kNonUnit, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + > trmm_op; + + // + // Launch the TRMM operation on the device + // + + cutlass::Status status = trmm_op({ + cutlass::gemm::GemmUniversalMode, // Trmm Problem Mode + {m, n, m/n}, // GemmCoord problem_size (k is based on left- or right-side mode) + batch_count, + {alpha}, // EpilogueOutputOp::Params epilogue_op_params + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int lda, + int ldb, + int ldc + }); + + A simplified view of the template is listed below. 
+ + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Side Mode for A (kLeft or kRight) + SideMode SideModeA, + + /// Fill Mode for A (kLower or kUpper) + FillMode FillModeA, + + /// DiagType for A (kNonUnit or kUnit) + DiagType DiagTypeA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages, + + /// Access granularity of A matrix in units of elements + int AlignmentA, + + /// Access granularity of B matrix in units of elements + int AlignmentB, + + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial, + + /// Operation performed by TRMM + typename Operator, + + /// Complex elementwise transformation on A operand + ComplexTransform TransformA + > + class Trmm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Side Mode for A + SideMode SideModeA, + /// Fill Mode for A + FillMode 
FillModeA, + /// DiagType for A + DiagType DiagTypeA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = epilogue::thread::LinearCombination< + ElementC_, + 128 / sizeof_bits::value, + ElementAccumulator_, + ElementAccumulator_, + epilogue::thread::ScaleType::OnlyAlphaScaling + >, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = 
false, + /// Operation performed by TRMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone> +class Trmm { + public: + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementB_, ElementA_>::type; + using LayoutAKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutB_, LayoutA_>::type; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), ElementA_, ElementB_>::type; + using LayoutBKernel = typename platform::conditional<(SideModeA == SideMode::kRight), LayoutA_, LayoutB_>::type; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static SideMode const kSideMode = SideModeA; + static FillMode const kFillMode = FillModeA; + static DiagType const kDiagType = DiagTypeA; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentAKernel = (SideModeA == SideMode::kRight) ? AlignmentB : AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentBKernel = (SideModeA == SideMode::kRight) ? 
AlignmentA : AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + // Complex Transform don't apply to B + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static ComplexTransform const kTransformAKernel = (SideModeA == SideMode::kRight) ? + ComplexTransform::kNone : TransformA; + static ComplexTransform const kTransformBKernel = (SideModeA == SideMode::kRight) ? + TransformA : ComplexTransform::kNone; + + /// Define the kernel + using TrmmKernel = typename kernel::DefaultTrmmUniversal< + ElementAKernel, + LayoutAKernel, + kTransformAKernel, + kAlignmentAKernel, + ElementBKernel, + LayoutBKernel, + kTransformBKernel, + kAlignmentBKernel, + kSideMode, + kFillMode, + kDiagType, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator + >::TrmmKernel; + + using Arguments = typename TrmmKernel::Arguments; + +private: + + /// Kernel parameters object + typename TrmmKernel::Params params_; +public: + + /// Constructs the TRMM. + Trmm() { } + + /// Determines whether the TRMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.batch_count > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = TrmmKernel::can_implement(args); + + if (SideModeA == SideMode::kInvalid) { + return Status::kErrorInvalidProblem; + } + + if (FillModeA == FillMode::kInvalid) { + return Status::kErrorInvalidProblem; + } + + if (DiagTypeA == DiagType::kInvalid) { + return Status::kErrorInvalidProblem; + } + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (kSplitKSerial && args.batch_count > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes TRMM state from arguments. 
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (kSplitKSerial) { + if (args.batch_count > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.batch_count > 1) { + return Status::kErrorInvalidProblem; + } + } + + int gemm_k_size = args.problem_size.k(); + + // Swapping argument for A and B, if A was on the right side (problem size doesn't need to change here). + if (kSideMode == SideMode::kRight) { + // Initialize the Params structure + params_ = typename TrmmKernel::Params{ + args.swapped_matrices(), + grid_tiled_shape, + gemm_k_size, + static_cast(workspace) + }; + + return Status::kSuccess; + } + + // Initialize the Params structure + params_ = typename TrmmKernel::Params{ + args, + grid_tiled_shape, + gemm_k_size, + static_cast(workspace) + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.batch_count > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(TrmmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename TrmmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +/******************************************************************************************************** + TRMM has 4 combinations based on Layouts {RowMajor, ColumnMajor} x Side mode {LeftSide, RightSide} + In templates and arguments to cutlass kernel, `matrix A` is always triangular, and `matrix B` is rectangular. + (adhering to the cuBLAS convention) + +For the mainloop and trmm kernel, `A` and `B` points to left-side and right-side matrices, respectively. + + Thus, for LeftSide mode `A` and `B` points to `matrix A` and `matrix B`, respectively. While for + the RightSide mode `A` and `B` points to `matrix B` and `matrix A`, respectively. + + Additionally, CUTLASS GEMM epilogue is always RowMajor, and ColumnMajor output is achieved by + transposing the GEMM problem. 
Thus, ColumnMajor output layout for TRMM requires: + - Transposing `matrix A` and `matrix B` layouts + - Swapping problem size m and n values + - Swapping LeftSide and RightSide mode + + RowMajor output: D = matrix A x matrix B + ColumnMajor output: D = matrix A x matrix B -> Transpose (D) = Transpose(matrix B) x Transpose(matrix A) + + {RowMajor, ColumnMajor} x Side Mode {LeftSide, RightSide} 4 cases: + 1. LeftSide mode and RowMajor output (default template) + 2. LeftSide mode and ColumnMajor output + 3. RightSide mode and RowMajor output + 4. RightSide mode and ColumnMajor output + + Mapping ColumnMajor output layout cases 2 and 4 to RowMajor efficient epilogue implementation: + + Case 2 -> Case 3: + D_col = matrix A x matrix B (LeftSide mode) + => Transpose(D_col) = Transpose(matrix B) x Transpose(matrix A) (RightSide mode) + + swap pointers for `A` and `B` call GEMM mainloop with RowMajor efficient-epilogue + + Case 4 -> Case 1: + D_col = matrix B x matrix A (RightSide mode) + => Transpose(D_col) = Transpose(matrix A) x Transpose(matrix B) (LeftSide mode) + + call GEMM mainloop for with RowMajor efficient-epilogue +********************************************************************************************************/ + +/// Partial specialization for column-major output exchanges problem size and operand. 
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Side Mode for A + SideMode SideModeA, + /// Fill Mode for A + FillMode FillModeA, + /// DiagType for A + DiagType DiagTypeA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// If true, kernel supports split-K as a serial reduction + bool SplitKSerial, + /// Operation performed by TRMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA> +class Trmm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = 
WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static SideMode const kSideMode = SideModeA; + static FillMode const kFillMode = FillModeA; + static DiagType const kDiagType = DiagTypeA; + // Changing SideMode as we change the layout + static SideMode const kSideModeT = (SideModeA == SideMode::kLeft) ? + SideMode::kRight : SideMode::kLeft; + // Changing FillMode as we change the layout + static FillMode const kFillModeT = (FillModeA == FillMode::kLower) ? + FillMode::kUpper : FillMode::kLower; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + // Complex Transform don't apply to B + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static bool const kSplitKSerial = SplitKSerial; + + using UnderlyingOperator = Trmm< + ElementA, + typename layout::LayoutTranspose::type, + kSideModeT, + kFillModeT, + kDiagType, + ElementB, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kAlignmentA, + kAlignmentB, + kSplitKSerial, + Operator, + TransformA + >; + + using Arguments = typename UnderlyingOperator::Arguments; + using TrmmKernel = typename UnderlyingOperator::TrmmKernel; + static int const kAlignmentC = UnderlyingOperator::kAlignmentC; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the TRMM. 
+ Trmm() { } + + /// Helper to construct a transposed equivalent for the underlying TRMM operator which is identical + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem_size(); + } + + /// Determines whether the TRMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Initializes TRMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/dispatch_policy.hpp b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/dispatch_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f42fc7ba89f7c4325634119e334a37d4ca340e5 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/gemm/dispatch_policy.hpp @@ -0,0 +1,1430 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/layout.hpp" +#include "cute/numeric/integral_constant.hpp" // cute::false_type +#include "cute/atom/copy_traits_sm100.hpp" +#include "cutlass/detail/collective/sm103_kernel_type.hpp" +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +template class U> +struct is_kernel_tag_of : cute::false_type {}; + +template